diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 7df8d8656b..1a29b86eae 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -269,5 +269,49 @@ class ModelsTable: except Exception: return False + def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: + try: + with get_db() as db: + # Get existing models + existing_models = db.query(Model).all() + existing_ids = {model.id for model in existing_models} + + # Prepare a set of new model IDs + new_model_ids = {model.id for model in models} + + # Update or insert models + for model in models: + if model.id in existing_ids: + db.query(Model).filter_by(id=model.id).update( + { + **model.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + else: + new_model = Model( + **{ + **model.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + db.add(new_model) + + # Remove models that are no longer present + for model in existing_models: + if model.id not in new_model_ids: + db.delete(model) + + db.commit() + + return [ + ModelModel.model_validate(model) for model in db.query(Model).all() + ] + except Exception as e: + log.exception(f"Error syncing models for user {user_id}: {e}") + return [] + Models = ModelsTable() diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 0cf3308f19..8bf87ba8a6 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -7,6 +7,8 @@ from open_webui.models.models import ( ModelUserResponse, Models, ) + +from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -78,6 +80,22 @@ async def create_new_model( ) +############################ +# SyncModels +############################ + + +class SyncModelsForm(BaseModel): + models: list[ModelModel] = [] + + +@router.post("/sync", response_model=list[ModelModel]) +async def sync_models( + request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) +): + return Models.sync_models(user.id, form_data.models) + + ########################### # GetModelById ###########################