wip: models

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:34:49 +04:00
parent 30bd4a2910
commit 446699e415
8 changed files with 69 additions and 57 deletions

View file

@ -199,7 +199,7 @@ async def generate_function_chat_completion(
return params return params
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id) model_info = await Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {}) metadata = form_data.pop("metadata", {})

View file

@ -1281,7 +1281,7 @@ async def get_models(
filtered_models.append(model) filtered_models.append(model)
continue continue
model_info = Models.get_model_by_id(model["id"]) model_info = await Models.get_model_by_id(model["id"])
if model_info: if model_info:
if ( if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
@ -1401,7 +1401,7 @@ async def chat_completion(
raise Exception("Model not found") raise Exception("Model not found")
model = request.app.state.MODELS[model_id] model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id) model_info = await Models.get_model_by_id(model_id)
# Check if user has access to the model # Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and ( if not BYPASS_MODEL_ACCESS_CONTROL and (

View file

@ -143,7 +143,7 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def insert_new_model( async def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
model = ModelModel( model = ModelModel(
@ -157,9 +157,9 @@ class ModelsTable:
try: try:
async with get_db() as db: async with get_db() as db:
result = Model(**model.model_dump()) result = Model(**model.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return ModelModel.model_validate(result) return ModelModel.model_validate(result)
@ -169,14 +169,19 @@ class ModelsTable:
log.exception(f"Failed to insert a new model: {e}") log.exception(f"Failed to insert a new model: {e}")
return None return None
def get_all_models(self) -> list[ModelModel]: async def get_all_models(self) -> list[ModelModel]:
async with get_db() as db: async with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] return [
ModelModel.model_validate(model)
for model in await db.query(Model).all()
]
async def get_models(self) -> list[ModelUserResponse]: async def get_models(self) -> list[ModelUserResponse]:
async with get_db() as db: async with get_db() as db:
models = [] models = []
for model in db.query(Model).filter(Model.base_model_id != None).all(): for model in (
await db.query(Model).filter(Model.base_model_id != None).all()
):
user = await Users.get_user_by_id(model.user_id) user = await Users.get_user_by_id(model.user_id)
models.append( models.append(
ModelUserResponse.model_validate( ModelUserResponse.model_validate(
@ -188,17 +193,19 @@ class ModelsTable:
) )
return models return models
def get_base_models(self) -> list[ModelModel]: async def get_base_models(self) -> list[ModelModel]:
async with get_db() as db: async with get_db() as db:
return [ return [
ModelModel.model_validate(model) ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all() for model in await db.query(Model)
.filter(Model.base_model_id == None)
.all()
] ]
async def get_models_by_user_id( async def get_models_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]: ) -> list[ModelUserResponse]:
models = self.get_models() models = await self.get_models()
return [ return [
model model
for model in models for model in models
@ -206,74 +213,78 @@ class ModelsTable:
or await has_access(user_id, permission, model.access_control) or await has_access(user_id, permission, model.access_control)
] ]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: async def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
async with get_db() as db: async with get_db() as db:
model = db.get(Model, id) model = await db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except Exception: except Exception:
return None return None
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: async def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
async with get_db() as db: async with get_db() as db:
try: try:
is_active = db.query(Model).filter_by(id=id).first().is_active is_active = (await db.query(Model).filter_by(id=id).first()).is_active
db.query(Model).filter_by(id=id).update( await db.query(Model).filter_by(id=id).update(
{ {
"is_active": not is_active, "is_active": not is_active,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return self.get_model_by_id(id) return await self.get_model_by_id(id)
except Exception: except Exception:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: async def update_model_by_id(
self, id: str, model: ModelForm
) -> Optional[ModelModel]:
try: try:
async with get_db() as db: async with get_db() as db:
# update only the fields that are present in the model # update only the fields that are present in the model
result = ( result = (
db.query(Model) await db.query(Model)
.filter_by(id=id) .filter_by(id=id)
.update(model.model_dump(exclude={"id"})) .update(model.model_dump(exclude={"id"}))
) )
db.commit() await db.commit()
model = db.get(Model, id) model = await db.get(Model, id)
db.refresh(model) await db.refresh(model)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
log.exception(f"Failed to update the model by id {id}: {e}") log.exception(f"Failed to update the model by id {id}: {e}")
return None return None
def delete_model_by_id(self, id: str) -> bool: async def delete_model_by_id(self, id: str) -> bool:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Model).filter_by(id=id).delete() await db.query(Model).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_all_models(self) -> bool: async def delete_all_models(self) -> bool:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Model).delete() await db.query(Model).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: async def sync_models(
self, user_id: str, models: list[ModelModel]
) -> list[ModelModel]:
try: try:
async with get_db() as db: async with get_db() as db:
# Get existing models # Get existing models
existing_models = db.query(Model).all() existing_models = await db.query(Model).all()
existing_ids = {model.id for model in existing_models} existing_ids = {model.id for model in existing_models}
# Prepare a set of new model IDs # Prepare a set of new model IDs
@ -282,7 +293,7 @@ class ModelsTable:
# Update or insert models # Update or insert models
for model in models: for model in models:
if model.id in existing_ids: if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update( await db.query(Model).filter_by(id=model.id).update(
{ {
**model.model_dump(), **model.model_dump(),
"user_id": user_id, "user_id": user_id,
@ -297,17 +308,18 @@ class ModelsTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.add(new_model) await db.add(new_model)
# Remove models that are no longer present # Remove models that are no longer present
for model in existing_models: for model in existing_models:
if model.id not in new_model_ids: if model.id not in new_model_ids:
db.delete(model) await db.delete(model)
db.commit() await db.commit()
return [ return [
ModelModel.model_validate(model) for model in db.query(Model).all() ModelModel.model_validate(model)
for model in await db.query(Model).all()
] ]
except Exception as e: except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}") log.exception(f"Error syncing models for user {user_id}: {e}")

View file

@ -602,7 +602,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})")
# Get all models # Get all models
models = Models.get_all_models() models = await Models.get_all_models()
log.info(f"Found {len(models)} models to check for knowledge base {id}") log.info(f"Found {len(models)} models to check for knowledge base {id}")
# Update models that reference this knowledge base # Update models that reference this knowledge base
@ -626,7 +626,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
access_control=model.access_control, access_control=model.access_control,
is_active=model.is_active, is_active=model.is_active,
) )
Models.update_model_by_id(model.id, model_form) await Models.update_model_by_id(model.id, model_form)
# Clean up vector DB # Clean up vector DB
try: try:

View file

@ -28,9 +28,9 @@ router = APIRouter()
@router.get("/", response_model=list[ModelUserResponse]) @router.get("/", response_model=list[ModelUserResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
return Models.get_models() return await Models.get_models()
else: else:
return Models.get_models_by_user_id(user.id) return await Models.get_models_by_user_id(user.id)
########################### ###########################
@ -40,7 +40,7 @@ async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
@router.get("/base", response_model=list[ModelResponse]) @router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)): async def get_base_models(user=Depends(get_admin_user)):
return Models.get_base_models() return await Models.get_base_models()
############################ ############################
@ -62,7 +62,7 @@ async def create_new_model(
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
model = Models.get_model_by_id(form_data.id) model = await Models.get_model_by_id(form_data.id)
if model: if model:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -70,7 +70,7 @@ async def create_new_model(
) )
else: else:
model = Models.insert_new_model(form_data, user.id) model = await Models.insert_new_model(form_data, user.id)
if model: if model:
return model return model
else: else:
@ -87,7 +87,7 @@ async def create_new_model(
@router.get("/export", response_model=list[ModelModel]) @router.get("/export", response_model=list[ModelModel])
async def export_models(user=Depends(get_admin_user)): async def export_models(user=Depends(get_admin_user)):
return Models.get_models() return await Models.get_models()
############################ ############################
@ -103,7 +103,7 @@ class SyncModelsForm(BaseModel):
async def sync_models( async def sync_models(
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
): ):
return Models.sync_models(user.id, form_data.models) return await Models.sync_models(user.id, form_data.models)
########################### ###########################
@ -114,7 +114,7 @@ async def sync_models(
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse]) @router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)): async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id) model = await Models.get_model_by_id(id)
if model: if model:
if ( if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
@ -136,14 +136,14 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/model/toggle", response_model=Optional[ModelResponse]) @router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id) model = await Models.get_model_by_id(id)
if model: if model:
if ( if (
user.role == "admin" user.role == "admin"
or model.user_id == user.id or model.user_id == user.id
or await has_access(user.id, "write", model.access_control) or await has_access(user.id, "write", model.access_control)
): ):
model = Models.toggle_model_by_id(id) model = await Models.toggle_model_by_id(id)
if model: if model:
return model return model
@ -175,7 +175,7 @@ async def update_model_by_id(
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
model = Models.get_model_by_id(id) model = await Models.get_model_by_id(id)
if not model: if not model:
raise HTTPException( raise HTTPException(
@ -193,7 +193,7 @@ async def update_model_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
model = Models.update_model_by_id(id, form_data) model = await Models.update_model_by_id(id, form_data)
return model return model

View file

@ -441,7 +441,7 @@ async def get_filtered_models(models, user):
# Filter models based on user access control # Filter models based on user access control
filtered_models = [] filtered_models = []
for model in models.get("models", []): for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"]) model_info = await Models.get_model_by_id(model["model"])
if model_info: if model_info:
if user.id == model_info.user_id or ( if user.id == model_info.user_id or (
await has_access( await has_access(
@ -1320,7 +1320,7 @@ async def generate_chat_completion(
del payload["metadata"] del payload["metadata"]
model_id = payload["model"] model_id = payload["model"]
model_info = Models.get_model_by_id(model_id) model_info = await Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:

View file

@ -383,7 +383,7 @@ async def get_filtered_models(models, user):
# Filter models based on user access control # Filter models based on user access control
filtered_models = [] filtered_models = []
for model in models.get("data", []): for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"]) model_info = await Models.get_model_by_id(model["id"])
if model_info: if model_info:
if user.id == model_info.user_id or ( if user.id == model_info.user_id or (
await has_access( await has_access(
@ -738,7 +738,7 @@ async def generate_chat_completion(
metadata = payload.pop("metadata", None) metadata = payload.pop("metadata", None)
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id) model_info = await Models.get_model_by_id(model_id)
# Check model info and override the payload # Check model info and override the payload
if model_info: if model_info:

View file

@ -145,7 +145,7 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
for function in Functions.get_functions_by_type("filter", active_only=True) for function in Functions.get_functions_by_type("filter", active_only=True)
] ]
custom_models = Models.get_all_models() custom_models = await Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id is None: if custom_model.base_model_id is None:
# Applied directly to a base model # Applied directly to a base model