wip: users, tools

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:19:54 +04:00
parent 44e9ae243d
commit e1da74541b
32 changed files with 289 additions and 248 deletions

View file

@ -174,7 +174,7 @@ async def generate_function_chat_completion(
pipe_id, _ = pipe_id.split(".", 1) pipe_id, _ = pipe_id.split(".", 1)
return pipe_id return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None): async def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None: if extra_params is None:
extra_params = {} extra_params = {}
@ -187,7 +187,9 @@ async def generate_function_chat_completion(
} }
if "__user__" in params and hasattr(function_module, "UserValves"): if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) user_valves = await Functions.get_user_valves_by_id_and_user_id(
pipe_id, user.id
)
try: try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves) params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e: except Exception as e:
@ -232,7 +234,7 @@ async def generate_function_chat_completion(
"__metadata__": metadata, "__metadata__": metadata,
"__request__": request, "__request__": request,
} }
extra_params["__tools__"] = get_tools( extra_params["__tools__"] = await get_tools(
request, request,
tool_ids, tool_ids,
user, user,
@ -261,7 +263,7 @@ async def generate_function_chat_completion(
function_module = get_function_module_by_id(request, pipe_id) function_module = get_function_module_by_id(request, pipe_id)
pipe = function_module.pipe pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params) params = await get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False): if form_data.get("stream", False):

View file

@ -530,7 +530,7 @@ async def lifespan(app: FastAPI):
# This should be blocking (sync) so functions are not deactivated on first /get_models calls # This should be blocking (sync) so functions are not deactivated on first /get_models calls
# when the first user lands on the / route. # when the first user lands on the / route.
log.info("Installing external dependencies of functions and tools...") log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies() await install_tool_and_function_dependencies()
app.state.redis = get_redis_connection( app.state.redis = get_redis_connection(
redis_url=REDIS_URL, redis_url=REDIS_URL,
@ -1267,11 +1267,11 @@ if audit_level != AuditLevel.NONE:
async def get_models( async def get_models(
request: Request, refresh: bool = False, user=Depends(get_verified_user) request: Request, refresh: bool = False, user=Depends(get_verified_user)
): ):
def get_filtered_models(models, user): async def get_filtered_models(models, user):
filtered_models = [] filtered_models = []
for model in models: for model in models:
if model.get("arena"): if model.get("arena"):
if has_access( if await has_access(
user.id, user.id,
type="read", type="read",
access_control=model.get("info", {}) access_control=model.get("info", {})
@ -1286,7 +1286,7 @@ async def get_models(
if ( if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
or user.id == model_info.user_id or user.id == model_info.user_id
or has_access( or await has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
) )
): ):
@ -1334,7 +1334,7 @@ async def get_models(
user.role == "user" user.role == "user"
or (user.role == "admin" and not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) or (user.role == "admin" and not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
) and not BYPASS_MODEL_ACCESS_CONTROL: ) and not BYPASS_MODEL_ACCESS_CONTROL:
models = get_filtered_models(models, user) models = await get_filtered_models(models, user)
log.debug( log.debug(
f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}" f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"
@ -1408,7 +1408,7 @@ async def chat_completion(
user.role != "admin" or not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS user.role != "admin" or not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
): ):
try: try:
check_model_access(user, model) await check_model_access(user, model)
except Exception as e: except Exception as e:
raise e raise e
else: else:
@ -1628,7 +1628,7 @@ async def get_app_config(request: Request):
if data is not None and "id" in data: if data is not None and "id" in data:
user = await Users.get_user_by_id(data["id"]) user = await Users.get_user_by_id(data["id"])
user_count = Users.get_num_users() user_count = await Users.get_num_users()
onboarding = False onboarding = False
if user is None: if user is None:

View file

@ -101,7 +101,7 @@ class ChannelTable:
channel channel
for channel in channels for channel in channels
if channel.user_id == user_id if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control) or await has_access(user_id, permission, channel.access_control)
] ]
async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:

View file

@ -143,7 +143,7 @@ class KnowledgeTable:
) )
return knowledge_bases return knowledge_bases
def get_knowledge_bases_by_user_id( async def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases() knowledge_bases = self.get_knowledge_bases()
@ -151,31 +151,31 @@ class KnowledgeTable:
knowledge_base knowledge_base
for knowledge_base in knowledge_bases for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control) or await has_access(user_id, permission, knowledge_base.access_control)
] ]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: async def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try: try:
async with get_db() as db: async with get_db() as db:
knowledge = db.query(Knowledge).filter_by(id=id).first() knowledge = await db.query(Knowledge).filter_by(id=id).first()
return KnowledgeModel.model_validate(knowledge) if knowledge else None return KnowledgeModel.model_validate(knowledge) if knowledge else None
except Exception: except Exception:
return None return None
def update_knowledge_by_id( async def update_knowledge_by_id(
self, id: str, form_data: KnowledgeForm, overwrite: bool = False self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
try: try:
async with get_db() as db: async with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id) knowledge = await self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update( await db.query(Knowledge).filter_by(id=id).update(
{ {
**form_data.model_dump(), **form_data.model_dump(),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return self.get_knowledge_by_id(id=id) return await self.get_knowledge_by_id(id=id)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None

View file

@ -195,7 +195,7 @@ class ModelsTable:
for model in db.query(Model).filter(Model.base_model_id == None).all() for model in db.query(Model).filter(Model.base_model_id == None).all()
] ]
def get_models_by_user_id( async def get_models_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]: ) -> list[ModelUserResponse]:
models = self.get_models() models = self.get_models()
@ -203,7 +203,7 @@ class ModelsTable:
model model
for model in models for model in models
if model.user_id == user_id if model.user_id == user_id
or has_access(user_id, permission, model.access_control) or await has_access(user_id, permission, model.access_control)
] ]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:

View file

@ -96,20 +96,20 @@ class NoteTable:
db.commit() db.commit()
return note return note
def get_notes(self) -> list[NoteModel]: async def get_notes(self) -> list[NoteModel]:
async with get_db() as db: async with get_db() as db:
notes = db.query(Note).order_by(Note.updated_at.desc()).all() notes = await db.query(Note).order_by(Note.updated_at.desc()).all()
return [NoteModel.model_validate(note) for note in notes] return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_user_id( async def get_notes_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[NoteModel]: ) -> list[NoteModel]:
notes = self.get_notes() notes = await self.get_notes()
return [ return [
note note
for note in notes for note in notes
if note.user_id == user_id if note.user_id == user_id
or has_access(user_id, permission, note.access_control) or await has_access(user_id, permission, note.access_control)
] ]
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(self, id: str) -> Optional[NoteModel]:

View file

@ -127,7 +127,7 @@ class PromptsTable:
prompt prompt
for prompt in prompts for prompt in prompts
if prompt.user_id == user_id if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control) or await has_access(user_id, permission, prompt.access_control)
] ]
def update_prompt_by_command( def update_prompt_by_command(

View file

@ -107,7 +107,7 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def insert_new_tool( async def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: list[dict] self, user_id: str, form_data: ToolForm, specs: list[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
async with get_db() as db: async with get_db() as db:
@ -123,9 +123,9 @@ class ToolsTable:
try: try:
result = Tool(**tool.model_dump()) result = Tool(**tool.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return ToolModel.model_validate(result) return ToolModel.model_validate(result)
else: else:
@ -134,10 +134,10 @@ class ToolsTable:
log.exception(f"Error creating a new tool: {e}") log.exception(f"Error creating a new tool: {e}")
return None return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: async def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
async with get_db() as db: async with get_db() as db:
tool = db.get(Tool, id) tool = await db.get(Tool, id)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except Exception: except Exception:
return None return None
@ -145,7 +145,7 @@ class ToolsTable:
async def get_tools(self) -> list[ToolUserModel]: async def get_tools(self) -> list[ToolUserModel]:
async with get_db() as db: async with get_db() as db:
tools = [] tools = []
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): for tool in await db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = await Users.get_user_by_id(tool.user_id) user = await Users.get_user_by_id(tool.user_id)
tools.append( tools.append(
ToolUserModel.model_validate( ToolUserModel.model_validate(
@ -157,35 +157,37 @@ class ToolsTable:
) )
return tools return tools
def get_tools_by_user_id( async def get_tools_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]: ) -> list[ToolUserModel]:
tools = self.get_tools() tools = await self.get_tools()
return [ return [
tool tool
for tool in tools for tool in tools
if tool.user_id == user_id if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control) or await has_access(user_id, permission, tool.access_control)
] ]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: async def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
async with get_db() as db: async with get_db() as db:
tool = db.get(Tool, id) tool = await db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
log.exception(f"Error getting tool valves by id {id}: {e}") log.exception(f"Error getting tool valves by id {id}: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: async def update_tool_valves_by_id(
self, id: str, valves: dict
) -> Optional[ToolValves]:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Tool).filter_by(id=id).update( await db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())} {"valves": valves, "updated_at": int(time.time())}
) )
db.commit() await db.commit()
return self.get_tool_by_id(id) return await self.get_tool_by_id(id)
except Exception: except Exception:
return None return None
@ -225,7 +227,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}) await Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
@ -234,25 +236,25 @@ class ToolsTable:
) )
return None return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: async def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Tool).filter_by(id=id).update( await db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())} {**updated, "updated_at": int(time.time())}
) )
db.commit() await db.commit()
tool = db.query(Tool).get(id) tool = await db.query(Tool).get(id)
db.refresh(tool) await db.refresh(tool)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except Exception: except Exception:
return None return None
def delete_tool_by_id(self, id: str) -> bool: async def delete_tool_by_id(self, id: str) -> bool:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Tool).filter_by(id=id).delete() await db.query(Tool).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:

View file

@ -517,7 +517,7 @@ def get_sources_from_items(
if note and ( if note and (
user.role == "admin" user.role == "admin"
or note.user_id == user.id or note.user_id == user.id
or has_access(user.id, "read", note.access_control) or await has_access(user.id, "read", note.access_control)
): ):
# User has access to the note # User has access to the note
query_result = { query_result = {
@ -581,7 +581,7 @@ def get_sources_from_items(
if knowledge_base and ( if knowledge_base and (
user.role == "admin" user.role == "admin"
or has_access(user.id, "read", knowledge_base.access_control) or await has_access(user.id, "read", knowledge_base.access_control)
): ):
file_ids = knowledge_base.data.get("file_ids", []) file_ids = knowledge_base.data.get("file_ids", [])

View file

@ -135,7 +135,7 @@ async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_verified_user) form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
): ):
if session_user: if session_user:
user = Users.update_user_by_id( user = await Users.update_user_by_id(
session_user.id, session_user.id,
{"profile_image_url": form_data.profile_image_url, "name": form_data.name}, {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
) )
@ -348,12 +348,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if not connection_user.bind(): if not connection_user.bind():
raise HTTPException(400, "Authentication failed.") raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(email) user = await Users.get_user_by_email(email)
if not user: if not user:
try: try:
role = ( role = (
"admin" "admin"
if not Users.has_users() if not await Users.has_users()
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
@ -463,7 +463,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_NAME_HEADER: if WEBUI_AUTH_TRUSTED_NAME_HEADER:
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
if not Users.get_user_by_email(email.lower()): if not await Users.get_user_by_email(email.lower()):
await signup( await signup(
request, request,
response, response,
@ -484,10 +484,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
admin_email = "admin@localhost" admin_email = "admin@localhost"
admin_password = "admin" admin_password = "admin"
if Users.get_user_by_email(admin_email.lower()): if await Users.get_user_by_email(admin_email.lower()):
user = await Auths.authenticate_user(admin_email.lower(), admin_password) user = await Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
if Users.has_users(): if await Users.has_users():
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup( await signup(
@ -556,7 +556,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse) @router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm): async def signup(request: Request, response: Response, form_data: SignupForm):
has_users = Users.has_users() has_users = await Users.has_users()
if WEBUI_AUTH: if WEBUI_AUTH:
if ( if (
@ -577,7 +577,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(form_data.email.lower()): if await Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
@ -733,7 +733,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(form_data.email.lower()): if await Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
@ -780,11 +780,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email: if admin_email:
admin = Users.get_user_by_email(admin_email) admin = await Users.get_user_by_email(admin_email)
if admin: if admin:
admin_name = admin.name admin_name = admin.name
else: else:
admin = Users.get_first_user() admin = await Users.get_first_user()
if admin: if admin:
admin_email = admin.email admin_email = admin.email
admin_name = admin.name admin_name = admin.name
@ -1025,7 +1025,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
) )
api_key = create_api_key() api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key) success = await Users.update_user_api_key_by_id(user.id, api_key)
if success: if success:
return { return {
@ -1038,14 +1038,14 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
# delete api key # delete api key
@router.delete("/api_key", response_model=bool) @router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)): async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None) success = await Users.update_user_api_key_by_id(user.id, None)
return success return success
# get api key # get api key
@router.get("/api_key", response_model=ApiKey) @router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)): async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id) api_key = await Users.get_user_api_key_by_id(user.id)
if api_key: if api_key:
return { return {
"api_key": api_key, "api_key": api_key,

View file

@ -80,7 +80,7 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -157,7 +157,7 @@ async def get_channel_messages(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -197,7 +197,7 @@ async def get_channel_messages(
async def send_notification(name, webui_url, channel, message, active_user_ids): async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control) users = await get_users_with_access("read", channel.access_control)
for user in users: for user in users:
if user.id in active_user_ids: if user.id in active_user_ids:
@ -236,7 +236,7 @@ async def post_new_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -341,7 +341,7 @@ async def get_channel_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -390,7 +390,7 @@ async def get_channel_thread_messages(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -452,7 +452,9 @@ async def update_message_by_id(
if ( if (
user.role != "admin" user.role != "admin"
and message.user_id != user.id and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control) and not await has_access(
user.id, type="read", access_control=channel.access_control
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -512,7 +514,7 @@ async def add_reaction_to_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -578,7 +580,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access( if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control user.id, type="read", access_control=channel.access_control
): ):
raise HTTPException( raise HTTPException(
@ -661,7 +663,9 @@ async def delete_message_by_id(
if ( if (
user.role != "admin" user.role != "admin"
and message.user_id != user.id and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control) and not await has_access(
user.id, type="read", access_control=channel.access_control
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()

View file

@ -440,10 +440,10 @@ async def update_function_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict]) @router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
try: try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id) user_valves = await Functions.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves return user_valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -482,7 +482,7 @@ async def get_function_user_valves_spec_by_id(
async def update_function_user_valves_by_id( async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
): ):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = get_function_module_from_cache(
@ -495,7 +495,7 @@ async def update_function_user_valves_by_id(
try: try:
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data) user_valves = UserValves(**form_data)
Functions.update_user_valves_by_id_and_user_id( await Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump() id, user.id, user_valves.model_dump()
) )
return user_valves.model_dump() return user_valves.model_dump()

View file

@ -90,7 +90,7 @@ async def update_group_by_id(
): ):
try: try:
if form_data.user_ids: if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data) group = Groups.update_group_by_id(id, form_data)
if group: if group:
@ -119,7 +119,7 @@ async def add_user_to_group(
): ):
try: try:
if form_data.user_ids: if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.add_users_to_group(id, form_data.user_ids) group = Groups.add_users_to_group(id, form_data.user_ids)
if group: if group:

View file

@ -261,11 +261,11 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
user.role == "admin" user.role == "admin"
or knowledge.user_id == user.id or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control) or await has_access(user.id, "read", knowledge.access_control)
): ):
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_file_metadatas_by_ids(file_ids) files = await Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse( return KnowledgeFilesResponse(
**knowledge.model_dump(), **knowledge.model_dump(),
@ -298,7 +298,7 @@ async def update_knowledge_by_id(
# Is the user the original creator, in a group with write access, or an admin # Is the user the original creator, in a group with write access, or an admin
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -348,7 +348,7 @@ def add_file_to_knowledge_by_id(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -432,7 +432,7 @@ def update_file_from_knowledge_by_id(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
@ -503,7 +503,7 @@ def remove_file_from_knowledge_by_id(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -591,7 +591,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -654,7 +654,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -697,7 +697,7 @@ def add_files_to_knowledge_batch(
if ( if (
knowledge.user_id != user.id knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control) and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(

View file

@ -119,7 +119,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
or model.user_id == user.id or model.user_id == user.id
or has_access(user.id, "read", model.access_control) or await has_access(user.id, "read", model.access_control)
): ):
return model return model
else: else:
@ -141,7 +141,7 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
user.role == "admin" user.role == "admin"
or model.user_id == user.id or model.user_id == user.id
or has_access(user.id, "write", model.access_control) or await has_access(user.id, "write", model.access_control)
): ):
model = Models.toggle_model_by_id(id) model = Models.toggle_model_by_id(id)
@ -185,7 +185,7 @@ async def update_model_by_id(
if ( if (
model.user_id != user.id model.user_id != user.id
and not has_access(user.id, "write", model.access_control) and not await has_access(user.id, "write", model.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -204,7 +204,7 @@ async def update_model_by_id(
@router.delete("/model/delete", response_model=bool) @router.delete("/model/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)): async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id) model = await Models.get_model_by_id(id)
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -214,18 +214,18 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
if ( if (
user.role != "admin" user.role != "admin"
and model.user_id != user.id and model.user_id != user.id
and not has_access(user.id, "write", model.access_control) and not await has_access(user.id, "write", model.access_control)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
result = Models.delete_model_by_id(id) result = await Models.delete_model_by_id(id)
return result return result
@router.delete("/delete/all", response_model=bool) @router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)): async def delete_all_models(user=Depends(get_admin_user)):
result = Models.delete_all_models() result = await Models.delete_all_models()
return result return result

View file

@ -133,7 +133,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
if user.role != "admin" and ( if user.role != "admin" and (
user.id != note.user_id user.id != note.user_id
and (not has_access(user.id, type="read", access_control=note.access_control)) and (
not await has_access(
user.id, type="read", access_control=note.access_control
)
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -167,7 +171,9 @@ async def update_note_by_id(
if user.role != "admin" and ( if user.role != "admin" and (
user.id != note.user_id user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control) and not await has_access(
user.id, type="write", access_control=note.access_control
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -212,7 +218,9 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
if user.role != "admin" and ( if user.role != "admin" and (
user.id != note.user_id user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control) and not await has_access(
user.id, type="write", access_control=note.access_control
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()

View file

@ -443,8 +443,10 @@ async def get_filtered_models(models, user):
for model in models.get("models", []): for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"]) model_info = Models.get_model_by_id(model["model"])
if model_info: if model_info:
if user.id == model_info.user_id or has_access( if user.id == model_info.user_id or (
user.id, type="read", access_control=model_info.access_control await has_access(
user.id, type="read", access_control=model_info.access_control
)
): ):
filtered_models.append(model) filtered_models.append(model)
return filtered_models return filtered_models
@ -1336,8 +1338,10 @@ async def generate_chat_completion(
if not bypass_filter and user.role == "user": if not bypass_filter and user.role == "user":
if not ( if not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or (
user.id, type="read", access_control=model_info.access_control await has_access(
user.id, type="read", access_control=model_info.access_control
)
) )
): ):
raise HTTPException( raise HTTPException(
@ -1442,8 +1446,10 @@ async def generate_openai_completion(
if user.role == "user": if user.role == "user":
if not ( if not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or (
user.id, type="read", access_control=model_info.access_control await has_access(
user.id, type="read", access_control=model_info.access_control
)
) )
): ):
raise HTTPException( raise HTTPException(
@ -1525,8 +1531,10 @@ async def generate_openai_chat_completion(
if user.role == "user": if user.role == "user":
if not ( if not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or (
user.id, type="read", access_control=model_info.access_control await has_access(
user.id, type="read", access_control=model_info.access_control
)
) )
): ):
raise HTTPException( raise HTTPException(
@ -1623,8 +1631,10 @@ async def get_openai_models(
for model in models: for model in models:
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if model_info: if model_info:
if user.id == model_info.user_id or has_access( if user.id == model_info.user_id or (
user.id, type="read", access_control=model_info.access_control await has_access(
user.id, type="read", access_control=model_info.access_control
)
): ):
filtered_models.append(model) filtered_models.append(model)
models = filtered_models models = filtered_models

View file

@ -385,8 +385,10 @@ async def get_filtered_models(models, user):
for model in models.get("data", []): for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"]) model_info = Models.get_model_by_id(model["id"])
if model_info: if model_info:
if user.id == model_info.user_id or has_access( if user.id == model_info.user_id or (
user.id, type="read", access_control=model_info.access_control await has_access(
user.id, type="read", access_control=model_info.access_control
)
): ):
filtered_models.append(model) filtered_models.append(model)
return filtered_models return filtered_models
@ -756,7 +758,7 @@ async def generate_chat_completion(
if not bypass_filter and user.role == "user": if not bypass_filter and user.role == "user":
if not ( if not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or await has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
) )
): ):

View file

@ -79,13 +79,13 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel]) @router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = await Prompts.get_prompt_by_command(f"/{command}")
if prompt: if prompt:
if ( if (
user.role == "admin" user.role == "admin"
or prompt.user_id == user.id or prompt.user_id == user.id
or has_access(user.id, "read", prompt.access_control) or await has_access(user.id, "read", prompt.access_control)
): ):
return prompt return prompt
else: else:
@ -106,7 +106,7 @@ async def update_prompt_by_command(
form_data: PromptForm, form_data: PromptForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = await Prompts.get_prompt_by_command(f"/{command}")
if not prompt: if not prompt:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -116,7 +116,7 @@ async def update_prompt_by_command(
# Is the user the original creator, in a group with write access, or an admin # Is the user the original creator, in a group with write access, or an admin
if ( if (
prompt.user_id != user.id prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control) and not await has_access(user.id, "write", prompt.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -124,7 +124,7 @@ async def update_prompt_by_command(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = await Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
return prompt return prompt
else: else:
@ -150,7 +150,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
if ( if (
prompt.user_id != user.id prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control) and not await has_access(user.id, "write", prompt.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(

View file

@ -297,7 +297,7 @@ def get_scim_auth(
) )
def user_to_scim(user: UserModel, request: Request) -> SCIMUser: async def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
"""Convert internal User model to SCIM User""" """Convert internal User model to SCIM User"""
# Parse display name into name components # Parse display name into name components
name_parts = user.name.split(" ", 1) if user.name else ["", ""] name_parts = user.name.split(" ", 1) if user.name else ["", ""]
@ -346,7 +346,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
) )
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: async def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
"""Convert internal Group model to SCIM Group""" """Convert internal Group model to SCIM Group"""
members = [] members = []
for user_id in group.user_ids: for user_id in group.user_ids:
@ -493,20 +493,20 @@ async def get_users(
# In production, you'd want a more robust filter parser # In production, you'd want a more robust filter parser
if "userName eq" in filter: if "userName eq" in filter:
email = filter.split('"')[1] email = filter.split('"')[1]
user = Users.get_user_by_email(email) user = await Users.get_user_by_email(email)
users_list = [user] if user else [] users_list = [user] if user else []
total = 1 if user else 0 total = 1 if user else 0
else: else:
response = Users.get_users(skip=skip, limit=limit) response = await Users.get_users(skip=skip, limit=limit)
users_list = response["users"] users_list = response["users"]
total = response["total"] total = response["total"]
else: else:
response = Users.get_users(skip=skip, limit=limit) response = await Users.get_users(skip=skip, limit=limit)
users_list = response["users"] users_list = response["users"]
total = response["total"] total = response["total"]
# Convert to SCIM format # Convert to SCIM format
scim_users = [user_to_scim(user, request) for user in users_list] scim_users = [await user_to_scim(user, request) for user in users_list]
return SCIMListResponse( return SCIMListResponse(
totalResults=total, totalResults=total,
@ -529,7 +529,7 @@ async def get_user(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
) )
return user_to_scim(user, request) return await user_to_scim(user, request)
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) @router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
@ -540,7 +540,7 @@ async def create_user(
): ):
"""Create SCIM User""" """Create SCIM User"""
# Check if user already exists # Check if user already exists
existing_user = Users.get_user_by_email(user_data.userName) existing_user = await Users.get_user_by_email(user_data.userName)
if existing_user: if existing_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
@ -579,7 +579,7 @@ async def create_user(
detail="Failed to create user", detail="Failed to create user",
) )
return user_to_scim(new_user, request) return await user_to_scim(new_user, request)
@router.put("/Users/{user_id}", response_model=SCIMUser) @router.put("/Users/{user_id}", response_model=SCIMUser)
@ -623,14 +623,14 @@ async def update_user(
update_data["profile_image_url"] = user_data.photos[0].value update_data["profile_image_url"] = user_data.photos[0].value
# Update user # Update user
updated_user = Users.update_user_by_id(user_id, update_data) updated_user = await Users.update_user_by_id(user_id, update_data)
if not updated_user: if not updated_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update user", detail="Failed to update user",
) )
return user_to_scim(updated_user, request) return await user_to_scim(updated_user, request)
@router.patch("/Users/{user_id}", response_model=SCIMUser) @router.patch("/Users/{user_id}", response_model=SCIMUser)
@ -669,7 +669,7 @@ async def patch_user(
# Update user # Update user
if update_data: if update_data:
updated_user = Users.update_user_by_id(user_id, update_data) updated_user = await Users.update_user_by_id(user_id, update_data)
if not updated_user: if not updated_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -678,7 +678,7 @@ async def patch_user(
else: else:
updated_user = user updated_user = user
return user_to_scim(updated_user, request) return await user_to_scim(updated_user, request)
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@ -695,7 +695,7 @@ async def delete_user(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
) )
success = Users.delete_user_by_id(user_id) success = await Users.delete_user_by_id(user_id)
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -716,7 +716,7 @@ async def get_groups(
): ):
"""List SCIM Groups""" """List SCIM Groups"""
# Get all groups # Get all groups
groups_list = Groups.get_groups() groups_list = await Groups.get_groups()
# Apply pagination # Apply pagination
total = len(groups_list) total = len(groups_list)
@ -725,7 +725,7 @@ async def get_groups(
paginated_groups = groups_list[start:end] paginated_groups = groups_list[start:end]
# Convert to SCIM format # Convert to SCIM format
scim_groups = [group_to_scim(group, request) for group in paginated_groups] scim_groups = [await group_to_scim(group, request) for group in paginated_groups]
return SCIMListResponse( return SCIMListResponse(
totalResults=total, totalResults=total,
@ -742,14 +742,14 @@ async def get_group(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Get SCIM Group by ID""" """Get SCIM Group by ID"""
group = Groups.get_group_by_id(group_id) group = await Groups.get_group_by_id(group_id)
if not group: if not group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group {group_id} not found", detail=f"Group {group_id} not found",
) )
return group_to_scim(group, request) return await group_to_scim(group, request)
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) @router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
@ -774,14 +774,14 @@ async def create_group(
) )
# Need to get the creating user's ID - we'll use the first admin # Need to get the creating user's ID - we'll use the first admin
admin_user = Users.get_super_admin_user() admin_user = await Users.get_super_admin_user()
if not admin_user: if not admin_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No admin user found", detail="No admin user found",
) )
new_group = Groups.insert_new_group(admin_user.id, form) new_group = await Groups.insert_new_group(admin_user.id, form)
if not new_group: if not new_group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -797,10 +797,10 @@ async def create_group(
description=new_group.description, description=new_group.description,
user_ids=member_ids, user_ids=member_ids,
) )
Groups.update_group_by_id(new_group.id, update_form) await Groups.update_group_by_id(new_group.id, update_form)
new_group = Groups.get_group_by_id(new_group.id) new_group = await Groups.get_group_by_id(new_group.id)
return group_to_scim(new_group, request) return await group_to_scim(new_group, request)
@router.put("/Groups/{group_id}", response_model=SCIMGroup) @router.put("/Groups/{group_id}", response_model=SCIMGroup)
@ -839,7 +839,7 @@ async def update_group(
detail="Failed to update group", detail="Failed to update group",
) )
return group_to_scim(updated_group, request) return await group_to_scim(updated_group, request)
@router.patch("/Groups/{group_id}", response_model=SCIMGroup) @router.patch("/Groups/{group_id}", response_model=SCIMGroup)
@ -899,7 +899,7 @@ async def patch_group(
detail="Failed to update group", detail="Failed to update group",
) )
return group_to_scim(updated_group, request) return await group_to_scim(updated_group, request)
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)

View file

@ -49,7 +49,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
request.app.state.config.TOOL_SERVER_CONNECTIONS request.app.state.config.TOOL_SERVER_CONNECTIONS
) )
tools = Tools.get_tools() tools = await Tools.get_tools()
for server in request.app.state.TOOL_SERVERS: for server in request.app.state.TOOL_SERVERS:
tools.append( tools.append(
ToolUserResponse( ToolUserResponse(
@ -83,7 +83,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
tool tool
for tool in tools for tool in tools
if tool.user_id == user.id if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control) or await has_access(user.id, "read", tool.access_control)
] ]
return tools return tools
@ -96,9 +96,9 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
@router.get("/list", response_model=list[ToolUserResponse]) @router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)): async def get_tool_list(user=Depends(get_verified_user)):
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
tools = Tools.get_tools() tools = await Tools.get_tools()
else: else:
tools = Tools.get_tools_by_user_id(user.id, "write") tools = await Tools.get_tools_by_user_id(user.id, "write")
return tools return tools
@ -184,7 +184,7 @@ async def load_tool_from_url(
@router.get("/export", response_model=list[ToolModel]) @router.get("/export", response_model=list[ToolModel])
async def export_tools(user=Depends(get_admin_user)): async def export_tools(user=Depends(get_admin_user)):
tools = Tools.get_tools() tools = await Tools.get_tools()
return tools return tools
@ -215,11 +215,11 @@ async def create_new_tools(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
tools = Tools.get_tool_by_id(form_data.id) tools = await Tools.get_tool_by_id(form_data.id)
if tools is None: if tools is None:
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
tool_module, frontmatter = load_tool_module_by_id( tool_module, frontmatter = await load_tool_module_by_id(
form_data.id, content=form_data.content form_data.id, content=form_data.content
) )
form_data.meta.manifest = frontmatter form_data.meta.manifest = frontmatter
@ -228,7 +228,7 @@ async def create_new_tools(
TOOLS[form_data.id] = tool_module TOOLS[form_data.id] = tool_module
specs = get_tool_specs(TOOLS[form_data.id]) specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs) tools = await Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = CACHE_DIR / "tools" / form_data.id tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True) tool_cache_dir.mkdir(parents=True, exist_ok=True)
@ -260,13 +260,13 @@ async def create_new_tools(
@router.get("/id/{id}", response_model=Optional[ToolModel]) @router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if tools: if tools:
if ( if (
user.role == "admin" user.role == "admin"
or tools.user_id == user.id or tools.user_id == user.id
or has_access(user.id, "read", tools.access_control) or await has_access(user.id, "read", tools.access_control)
): ):
return tools return tools
else: else:
@ -288,7 +288,7 @@ async def update_tools_by_id(
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if not tools: if not tools:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -298,7 +298,7 @@ async def update_tools_by_id(
# Is the user the original creator, in a group with write access, or an admin # Is the user the original creator, in a group with write access, or an admin
if ( if (
tools.user_id != user.id tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control) and not await has_access(user.id, "write", tools.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -308,7 +308,9 @@ async def update_tools_by_id(
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content) tool_module, frontmatter = await load_tool_module_by_id(
id, content=form_data.content
)
form_data.meta.manifest = frontmatter form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
@ -322,7 +324,7 @@ async def update_tools_by_id(
} }
log.debug(updated) log.debug(updated)
tools = Tools.update_tool_by_id(id, updated) tools = await Tools.update_tool_by_id(id, updated)
if tools: if tools:
return tools return tools
@ -348,7 +350,7 @@ async def update_tools_by_id(
async def delete_tools_by_id( async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user)
): ):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if not tools: if not tools:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -357,7 +359,7 @@ async def delete_tools_by_id(
if ( if (
tools.user_id != user.id tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control) and not await has_access(user.id, "write", tools.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -365,7 +367,7 @@ async def delete_tools_by_id(
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
result = Tools.delete_tool_by_id(id) result = await Tools.delete_tool_by_id(id)
if result: if result:
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
if id in TOOLS: if id in TOOLS:
@ -381,10 +383,10 @@ async def delete_tools_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict]) @router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if tools: if tools:
try: try:
valves = Tools.get_tool_valves_by_id(id) valves = await Tools.get_tool_valves_by_id(id)
return valves return valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -407,12 +409,12 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_valves_spec_by_id( async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user)
): ):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if tools: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
tools_module, _ = load_tool_module_by_id(id) tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "Valves"): if hasattr(tools_module, "Valves"):
@ -435,7 +437,7 @@ async def get_tools_valves_spec_by_id(
async def update_tools_valves_by_id( async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
): ):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if not tools: if not tools:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -444,7 +446,7 @@ async def update_tools_valves_by_id(
if ( if (
tools.user_id != user.id tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control) and not await has_access(user.id, "write", tools.access_control)
and user.role != "admin" and user.role != "admin"
): ):
raise HTTPException( raise HTTPException(
@ -455,7 +457,7 @@ async def update_tools_valves_by_id(
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
tools_module, _ = load_tool_module_by_id(id) tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module request.app.state.TOOLS[id] = tools_module
if not hasattr(tools_module, "Valves"): if not hasattr(tools_module, "Valves"):
@ -468,7 +470,7 @@ async def update_tools_valves_by_id(
try: try:
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data) valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump()) await Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump() return valves.model_dump()
except Exception as e: except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}") log.exception(f"Failed to update tool valves by id {id}: {e}")
@ -485,10 +487,10 @@ async def update_tools_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict]) @router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if tools: if tools:
try: try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) user_valves = await Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves return user_valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -506,12 +508,12 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_user_valves_spec_by_id( async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user)
): ):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if tools: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
tools_module, _ = load_tool_module_by_id(id) tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"): if hasattr(tools_module, "UserValves"):
@ -529,13 +531,13 @@ async def get_tools_user_valves_spec_by_id(
async def update_tools_user_valves_by_id( async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
): ):
tools = Tools.get_tool_by_id(id) tools = await Tools.get_tool_by_id(id)
if tools: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
tools_module, _ = load_tool_module_by_id(id) tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"): if hasattr(tools_module, "UserValves"):
@ -544,7 +546,7 @@ async def update_tools_user_valves_by_id(
try: try:
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data) user_valves = UserValves(**form_data)
Tools.update_user_valves_by_id_and_user_id( await Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump() id, user.id, user_valves.model_dump()
) )
return user_valves.model_dump() return user_valves.model_dump()

View file

@ -88,14 +88,14 @@ async def get_users(
if direction: if direction:
filter["direction"] = direction filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit) return await Users.get_users(filter=filter, skip=skip, limit=limit)
@router.get("/all", response_model=UserInfoListResponse) @router.get("/all", response_model=UserInfoListResponse)
async def get_all_users( async def get_all_users(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
return Users.get_users() return await Users.get_users()
############################ ############################
@ -205,7 +205,7 @@ async def update_default_user_permissions(
@router.get("/user/settings", response_model=Optional[UserSettings]) @router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)): async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id) user = await Users.get_user_by_id(user.id)
if user: if user:
return user.settings return user.settings
else: else:
@ -237,7 +237,7 @@ async def update_user_settings_by_session_user(
# If the user is not an admin and does not have permission to use tool servers, remove the key # If the user is not an admin and does not have permission to use tool servers, remove the key
updated_user_settings["ui"].pop("toolServers", None) updated_user_settings["ui"].pop("toolServers", None)
user = Users.update_user_settings_by_id(user.id, updated_user_settings) user = await Users.update_user_settings_by_id(user.id, updated_user_settings)
if user: if user:
return user.settings return user.settings
else: else:
@ -254,7 +254,7 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict]) @router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)): async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id) user = await Users.get_user_by_id(user.id)
if user: if user:
return user.info return user.info
else: else:
@ -273,12 +273,14 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
async def update_user_info_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user) form_data: dict, user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(user.id) user = await Users.get_user_by_id(user.id)
if user: if user:
if user.info is None: if user.info is None:
user.info = {} user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) user = await Users.update_user_by_id(
user.id, {"info": {**user.info, **form_data}}
)
if user: if user:
return user.info return user.info
else: else:
@ -343,7 +345,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.get("/{user_id}/profile/image") @router.get("/{user_id}/profile/image")
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
user = Users.get_user_by_id(user_id) user = await Users.get_user_by_id(user_id)
if user: if user:
if user.profile_image_url: if user.profile_image_url:
# check if it's url or base64 # check if it's url or base64
@ -398,7 +400,7 @@ async def update_user_by_id(
): ):
# Prevent modification of the primary admin user by other admins # Prevent modification of the primary admin user by other admins
try: try:
first_user = Users.get_first_user() first_user = await Users.get_first_user()
if first_user: if first_user:
if user_id == first_user.id: if user_id == first_user.id:
if session_user.id != user_id: if session_user.id != user_id:
@ -472,7 +474,7 @@ async def update_user_by_id(
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
# Prevent deletion of the primary admin user # Prevent deletion of the primary admin user
try: try:
first_user = Users.get_first_user() first_user = await Users.get_first_user()
if first_user and user_id == first_user.id: if first_user and user_id == first_user.id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,

View file

@ -263,7 +263,7 @@ async def connect(sid, environ, auth):
data = decode_token(auth["token"]) data = decode_token(auth["token"])
if data is not None and "id" in data: if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = await Users.get_user_by_id(data["id"])
if user: if user:
SESSION_POOL[sid] = user.model_dump() SESSION_POOL[sid] = user.model_dump()
@ -284,7 +284,7 @@ async def user_join(sid, data):
if data is None or "id" not in data: if data is None or "id" not in data:
return return
user = Users.get_user_by_id(data["id"]) user = await Users.get_user_by_id(data["id"])
if not user: if not user:
return return
@ -312,7 +312,7 @@ async def join_channel(sid, data):
if data is None or "id" not in data: if data is None or "id" not in data:
return return
user = Users.get_user_by_id(data["id"]) user = await Users.get_user_by_id(data["id"])
if not user: if not user:
return return
@ -333,7 +333,7 @@ async def join_note(sid, data):
if token_data is None or "id" not in token_data: if token_data is None or "id" not in token_data:
return return
user = Users.get_user_by_id(token_data["id"]) user = await Users.get_user_by_id(token_data["id"])
if not user: if not user:
return return
@ -345,7 +345,9 @@ async def join_note(sid, data):
if ( if (
user.role != "admin" user.role != "admin"
and user.id != note.user_id and user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control) and not await has_access(
user.id, type="read", access_control=note.access_control
)
): ):
log.error(f"User {user.id} does not have access to note {data['note_id']}") log.error(f"User {user.id} does not have access to note {data['note_id']}")
return return
@ -400,7 +402,7 @@ async def ydoc_document_join(sid, data):
if ( if (
user.get("role") != "admin" user.get("role") != "admin"
and user.get("id") != note.user_id and user.get("id") != note.user_id
and not has_access( and not await has_access(
user.get("id"), type="read", access_control=note.access_control user.get("id"), type="read", access_control=note.access_control
) )
): ):
@ -470,14 +472,14 @@ async def document_save_handler(document_id, data, user):
if ( if (
user.get("role") != "admin" user.get("role") != "admin"
and user.get("id") != note.user_id and user.get("id") != note.user_id
and not has_access( and not await has_access(
user.get("id"), type="read", access_control=note.access_control user.get("id"), type="read", access_control=note.access_control
) )
): ):
log.error(f"User {user.get('id')} does not have access to note {note_id}") log.error(f"User {user.get('id')} does not have access to note {note_id}")
return return
Notes.update_note_by_id(note_id, NoteUpdateForm(data=data)) await Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
@sio.on("ydoc:document:state") @sio.on("ydoc:document:state")

View file

@ -107,7 +107,7 @@ def has_permission(
return get_permission(default_permissions, permission_hierarchy) return get_permission(default_permissions, permission_hierarchy)
def has_access( async def has_access(
user_id: str, user_id: str,
type: str = "write", type: str = "write",
access_control: Optional[dict] = None, access_control: Optional[dict] = None,
@ -115,7 +115,7 @@ def has_access(
if access_control is None: if access_control is None:
return type == "read" return type == "read"
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = await Groups.get_groups_by_member_id(user_id)
user_group_ids = [group.id for group in user_groups] user_group_ids = [group.id for group in user_groups]
permission_access = access_control.get(type, {}) permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", []) permitted_group_ids = permission_access.get("group_ids", [])
@ -127,11 +127,11 @@ def has_access(
# Get all users with access to a resource # Get all users with access to a resource
def get_users_with_access( async def get_users_with_access(
type: str = "write", access_control: Optional[dict] = None type: str = "write", access_control: Optional[dict] = None
) -> List[UserModel]: ) -> List[UserModel]:
if access_control is None: if access_control is None:
return Users.get_users() return await Users.get_users()
permission_access = access_control.get(type, {}) permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", []) permitted_group_ids = permission_access.get("group_ids", [])
@ -140,8 +140,8 @@ def get_users_with_access(
user_ids_with_access = set(permitted_user_ids) user_ids_with_access = set(permitted_user_ids)
for group_id in permitted_group_ids: for group_id in permitted_group_ids:
group_user_ids = Groups.get_group_user_ids_by_id(group_id) group_user_ids = await Groups.get_group_user_ids_by_id(group_id)
if group_user_ids: if group_user_ids:
user_ids_with_access.update(group_user_ids) user_ids_with_access.update(group_user_ids)
return Users.get_users_by_user_ids(list(user_ids_with_access)) return await Users.get_users_by_user_ids(list(user_ids_with_access))

View file

@ -206,7 +206,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
return None return None
def get_current_user( async def get_current_user(
request: Request, request: Request,
response: Response, response: Response,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@ -270,7 +270,7 @@ def get_current_user(
) )
if data is not None and "id" in data: if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = await Users.get_user_by_id(data["id"])
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -303,6 +303,7 @@ def get_current_user(
# Refresh the user's last active timestamp asynchronously # Refresh the user's last active timestamp asynchronously
# to prevent blocking the request # to prevent blocking the request
if background_tasks: if background_tasks:
background_tasks.add_task(Users.update_user_last_active_by_id, user.id) background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
return user return user
else: else:
@ -312,8 +313,8 @@ def get_current_user(
) )
def get_current_user_by_api_key(api_key: str): async def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key) user = await Users.get_user_by_api_key(api_key)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
@ -329,7 +330,7 @@ def get_current_user_by_api_key(api_key: str):
current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key") current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id) await Users.update_user_last_active_by_id(user.id)
return user return user

View file

@ -199,7 +199,7 @@ async def generate_chat_completion(
# Check if user has access to the model # Check if user has access to the model
if not bypass_filter and user.role == "user": if not bypass_filter and user.role == "user":
try: try:
check_model_access(user, model) await check_model_access(user, model)
except Exception as e: except Exception as e:
raise e raise e
@ -424,8 +424,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
try: try:
if hasattr(function_module, "UserValves"): if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves( __user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id( **(
action_id, user.id await Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
) )
) )
except Exception as e: except Exception as e:

View file

@ -70,7 +70,7 @@ async def generate_embeddings(
# Access filtering # Access filtering
if not getattr(request.state, "direct", False): if not getattr(request.state, "direct", False):
if not bypass_filter and user.role == "user": if not bypass_filter and user.role == "user":
check_model_access(user, model) await check_model_access(user, model)
# Ollama backend # Ollama backend
if model.get("owned_by") == "ollama": if model.get("owned_by") == "ollama":

View file

@ -109,8 +109,10 @@ async def process_filter_functions(
if hasattr(function_module, "UserValves"): if hasattr(function_module, "UserValves"):
try: try:
params["__user__"]["valves"] = function_module.UserValves( params["__user__"]["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id( **(
filter_id, params["__user__"]["id"] await Functions.get_user_valves_by_id_and_user_id(
filter_id, params["__user__"]["id"]
)
) )
) )
except Exception as e: except Exception as e:

View file

@ -311,9 +311,9 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
return models return models
def check_model_access(user, model): async def check_model_access(user, model):
if model.get("arena"): if model.get("arena"):
if not has_access( if not await has_access(
user.id, user.id,
type="read", type="read",
access_control=model.get("info", {}) access_control=model.get("info", {})
@ -322,12 +322,12 @@ def check_model_access(user, model):
): ):
raise Exception("Model not found") raise Exception("Model not found")
else: else:
model_info = Models.get_model_by_id(model.get("id")) model_info = await Models.get_model_by_id(model.get("id"))
if not model_info: if not model_info:
raise Exception("Model not found") raise Exception("Model not found")
elif not ( elif not (
user.id == model_info.user_id user.id == model_info.user_id
or has_access( or await has_access(
user.id, type="read", access_control=model_info.access_control user.id, type="read", access_control=model_info.access_control
) )
): ):

View file

@ -89,8 +89,8 @@ class OAuthManager:
def get_client(self, provider_name): def get_client(self, provider_name):
return self.oauth.create_client(provider_name) return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data): async def get_user_role(self, user, user_data):
user_count = Users.get_num_users() user_count = await Users.get_num_users()
if user and user_count == 1: if user and user_count == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
log.debug("Assigning the only user the admin role") log.debug("Assigning the only user the admin role")
@ -147,7 +147,7 @@ class OAuthManager:
return role return role
def update_user_groups(self, user, user_data, default_permissions): async def update_user_groups(self, user, user_data, default_permissions):
log.debug("Running OAUTH Group management") log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
@ -172,8 +172,10 @@ class OAuthManager:
else: else:
user_oauth_groups = [] user_oauth_groups = []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) user_current_groups: list[GroupModel] = await Groups.get_groups_by_member_id(
all_available_groups: list[GroupModel] = Groups.get_groups() user.id
)
all_available_groups: list[GroupModel] = await Groups.get_groups()
# Create groups if they don't exist and creation is enabled # Create groups if they don't exist and creation is enabled
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
@ -181,7 +183,7 @@ class OAuthManager:
all_group_names = {g.name for g in all_available_groups} all_group_names = {g.name for g in all_available_groups}
groups_created = False groups_created = False
# Determine creator ID: Prefer admin, fallback to current user if no admin exists # Determine creator ID: Prefer admin, fallback to current user if no admin exists
admin_user = Users.get_super_admin_user() admin_user = await Users.get_super_admin_user()
creator_id = admin_user.id if admin_user else user.id creator_id = admin_user.id if admin_user else user.id
log.debug(f"Using creator ID {creator_id} for potential group creation.") log.debug(f"Using creator ID {creator_id} for potential group creation.")
@ -198,7 +200,7 @@ class OAuthManager:
user_ids=[], # Start with no users, user will be added later by subsequent logic user_ids=[], # Start with no users, user will be added later by subsequent logic
) )
# Use determined creator ID (admin or fallback to current user) # Use determined creator ID (admin or fallback to current user)
created_group = Groups.insert_new_group( created_group = await Groups.insert_new_group(
creator_id, new_group_form creator_id, new_group_form
) )
if created_group: if created_group:
@ -217,7 +219,7 @@ class OAuthManager:
# Refresh the list of all available groups if any were created # Refresh the list of all available groups if any were created
if groups_created: if groups_created:
all_available_groups = Groups.get_groups() all_available_groups = await Groups.get_groups()
log.debug("Refreshed list of all available groups after creation.") log.debug("Refreshed list of all available groups after creation.")
log.debug(f"Oauth Groups claim: {oauth_claim}") log.debug(f"Oauth Groups claim: {oauth_claim}")
@ -430,21 +432,21 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists # Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub) user = await Users.get_user_by_oauth_sub(provider_sub)
if not user: if not user:
# If the user does not exist, check if merging is enabled # If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email # Check if the user exists by email
user = Users.get_user_by_email(email) user = await Users.get_user_by_email(email)
if user: if user:
# Update the user with the new oauth sub # Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub) await Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user: if user:
determined_role = self.get_user_role(user, user_data) determined_role = await self.get_user_role(user, user_data)
if user.role != determined_role: if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role) await Users.update_user_role_by_id(user.id, determined_role)
# Update profile picture if enabled and different from current # Update profile picture if enabled and different from current
if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN: if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
@ -457,7 +459,7 @@ class OAuthManager:
new_picture_url, token.get("access_token") new_picture_url, token.get("access_token")
) )
if processed_picture_url != user.profile_image_url: if processed_picture_url != user.profile_image_url:
Users.update_user_profile_image_url_by_id( await Users.update_user_profile_image_url_by_id(
user.id, processed_picture_url user.id, processed_picture_url
) )
log.debug(f"Updated profile picture for user {user.email}") log.debug(f"Updated profile picture for user {user.email}")
@ -466,7 +468,7 @@ class OAuthManager:
# If the user does not exist, check if signups are enabled # If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP: if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists # Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(email) existing_user = await Users.get_user_by_email(email)
if existing_user: if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -488,7 +490,7 @@ class OAuthManager:
log.warning("Username claim is missing, using email as name") log.warning("Username claim is missing, using email as name")
name = email name = email
role = self.get_user_role(None, user_data) role = await self.get_user_role(None, user_data)
user = await Auths.insert_new_auth( user = await Auths.insert_new_auth(
email=email, email=email,
@ -523,7 +525,7 @@ class OAuthManager:
) )
if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin": if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin":
self.update_user_groups( await self.update_user_groups(
user=user, user=user,
user_data=user_data, user_data=user_data,
default_permissions=request.app.state.config.USER_PERMISSIONS, default_permissions=request.app.state.config.USER_PERMISSIONS,

View file

@ -68,17 +68,17 @@ def replace_imports(content):
return content return content
def load_tool_module_by_id(tool_id, content=None): async def load_tool_module_by_id(tool_id, content=None):
if content is None: if content is None:
tool = Tools.get_tool_by_id(tool_id) tool = await Tools.get_tool_by_id(tool_id)
if not tool: if not tool:
raise Exception(f"Toolkit not found: {tool_id}") raise Exception(f"Toolkit not found: {tool_id}")
content = tool.content content = tool.content
content = replace_imports(content) content = replace_imports(content)
Tools.update_tool_by_id(tool_id, {"content": content}) await Tools.update_tool_by_id(tool_id, {"content": content})
else: else:
frontmatter = extract_frontmatter(content) frontmatter = extract_frontmatter(content)
# Install required packages found within the frontmatter # Install required packages found within the frontmatter
@ -241,7 +241,7 @@ def install_frontmatter_requirements(requirements: str):
log.info("No requirements found in frontmatter.") log.info("No requirements found in frontmatter.")
def install_tool_and_function_dependencies(): async def install_tool_and_function_dependencies():
""" """
Install all dependencies for all admin tools and active functions. Install all dependencies for all admin tools and active functions.
@ -249,8 +249,8 @@ def install_tool_and_function_dependencies():
and then installing them using pip. Duplicates or similar version specifications are and then installing them using pip. Duplicates or similar version specifications are
handled by pip as much as possible. handled by pip as much as possible.
""" """
function_list = Functions.get_functions(active_only=True) function_list = await Functions.get_functions(active_only=True)
tool_list = Tools.get_tools() tool_list = await Tools.get_tools()
all_dependencies = "" all_dependencies = ""
try: try:

View file

@ -68,13 +68,13 @@ def get_async_tool_function_and_apply_extra_params(
return new_function return new_function
def get_tools( async def get_tools(
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]: ) -> dict[str, dict]:
tools_dict = {} tools_dict = {}
for tool_id in tool_ids: for tool_id in tool_ids:
tool = Tools.get_tool_by_id(tool_id) tool = await Tools.get_tool_by_id(tool_id)
if tool is None: if tool is None:
if tool_id.startswith("server:"): if tool_id.startswith("server:"):
server_idx = int(tool_id.split(":")[1]) server_idx = int(tool_id.split(":")[1])
@ -140,18 +140,18 @@ def get_tools(
else: else:
module = request.app.state.TOOLS.get(tool_id, None) module = request.app.state.TOOLS.get(tool_id, None)
if module is None: if module is None:
module, _ = load_tool_module_by_id(tool_id) module, _ = await load_tool_module_by_id(tool_id)
request.app.state.TOOLS[tool_id] = module request.app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id extra_params["__id__"] = tool_id
# Set valves for the tool # Set valves for the tool
if hasattr(module, "valves") and hasattr(module, "Valves"): if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {} valves = await Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves) module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"): if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) **(await Tools.get_user_valves_by_id_and_user_id(tool_id, user.id))
) )
for spec in tool.specs: for spec in tool.specs: