From 446699e4150de443a0bdef82d3b80d1f58b9810d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 14 Aug 2025 16:34:49 +0400 Subject: [PATCH] wip: models --- backend/open_webui/functions.py | 2 +- backend/open_webui/main.py | 4 +- backend/open_webui/models/models.py | 82 ++++++++++++++----------- backend/open_webui/routers/knowledge.py | 4 +- backend/open_webui/routers/models.py | 24 ++++---- backend/open_webui/routers/ollama.py | 4 +- backend/open_webui/routers/openai.py | 4 +- backend/open_webui/utils/models.py | 2 +- 8 files changed, 69 insertions(+), 57 deletions(-) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 422d121382..b2c0a14352 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -199,7 +199,7 @@ async def generate_function_chat_completion( return params model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) + model_info = await Models.get_model_by_id(model_id) metadata = form_data.pop("metadata", {}) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index d3b302ec5c..423f604824 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1281,7 +1281,7 @@ async def get_models( filtered_models.append(model) continue - model_info = Models.get_model_by_id(model["id"]) + model_info = await Models.get_model_by_id(model["id"]) if model_info: if ( (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) @@ -1401,7 +1401,7 @@ async def chat_completion( raise Exception("Model not found") model = request.app.state.MODELS[model_id] - model_info = Models.get_model_by_id(model_id) + model_info = await Models.get_model_by_id(model_id) # Check if user has access to the model if not BYPASS_MODEL_ACCESS_CONTROL and ( diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 28fd9bd51b..d7cfd01a17 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -143,7 +143,7 @@ class ModelForm(BaseModel): class ModelsTable: - def insert_new_model( + async def insert_new_model( self, form_data: ModelForm, user_id: str ) -> Optional[ModelModel]: model = ModelModel( @@ -157,9 +157,9 @@ class ModelsTable: try: async with get_db() as db: result = Model(**model.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return ModelModel.model_validate(result) @@ -169,14 +169,19 @@ class ModelsTable: log.exception(f"Failed to insert a new model: {e}") return None - def get_all_models(self) -> list[ModelModel]: + async def get_all_models(self) -> list[ModelModel]: async with get_db() as db: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] + return [ + ModelModel.model_validate(model) + for model in await db.query(Model).all() + ] async def get_models(self) -> list[ModelUserResponse]: async with get_db() as db: models = [] - for model in db.query(Model).filter(Model.base_model_id != None).all(): + for model in ( + await db.query(Model).filter(Model.base_model_id != None).all() + ): user = await Users.get_user_by_id(model.user_id) models.append( ModelUserResponse.model_validate( @@ -188,17 +193,19 @@ class ModelsTable: ) return models - def get_base_models(self) -> list[ModelModel]: + async def get_base_models(self) -> list[ModelModel]: async with get_db() as db: return [ ModelModel.model_validate(model) - for model in db.query(Model).filter(Model.base_model_id == None).all() + for model in await db.query(Model) + .filter(Model.base_model_id == None) + .all() ] async def get_models_by_user_id( self, user_id: str, permission: str = "write" ) -> list[ModelUserResponse]: - models = self.get_models() + models = await self.get_models() return [ model for model in models @@ -206,74 +213,78 @@ class ModelsTable: or await has_access(user_id, permission, model.access_control) ] - def get_model_by_id(self, id: str) -> Optional[ModelModel]: + async def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: async with get_db() as db: - model = db.get(Model, id) + model = await db.get(Model, id) return ModelModel.model_validate(model) except Exception: return None - def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: + async def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: async with get_db() as db: try: - is_active = db.query(Model).filter_by(id=id).first().is_active + is_active = (await db.query(Model).filter_by(id=id).first()).is_active - db.query(Model).filter_by(id=id).update( + await db.query(Model).filter_by(id=id).update( { "is_active": not is_active, "updated_at": int(time.time()), } ) - db.commit() + await db.commit() - return self.get_model_by_id(id) + return await self.get_model_by_id(id) except Exception: return None - def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + async def update_model_by_id( + self, id: str, model: ModelForm + ) -> Optional[ModelModel]: try: async with get_db() as db: # update only the fields that are present in the model result = ( - db.query(Model) + await db.query(Model) .filter_by(id=id) .update(model.model_dump(exclude={"id"})) ) - db.commit() + await db.commit() - model = db.get(Model, id) - db.refresh(model) + model = await db.get(Model, id) + await db.refresh(model) return ModelModel.model_validate(model) except Exception as e: log.exception(f"Failed to update the model by id {id}: {e}") return None - def delete_model_by_id(self, id: str) -> bool: + async def delete_model_by_id(self, id: str) -> bool: try: async with get_db() as db: - db.query(Model).filter_by(id=id).delete() - db.commit() + await db.query(Model).filter_by(id=id).delete() + await db.commit() return True except Exception: return False - def delete_all_models(self) -> bool: + async def delete_all_models(self) -> bool: try: async with get_db() as db: - db.query(Model).delete() - db.commit() + await db.query(Model).delete() + await db.commit() return True except Exception: return False - def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: + async def sync_models( + self, user_id: str, models: list[ModelModel] + ) -> list[ModelModel]: try: async with get_db() as db: # Get existing models - existing_models = db.query(Model).all() + existing_models = await db.query(Model).all() existing_ids = {model.id for model in existing_models} # Prepare a set of new model IDs @@ -282,7 +293,7 @@ class ModelsTable: # Update or insert models for model in models: if model.id in existing_ids: - db.query(Model).filter_by(id=model.id).update( + await db.query(Model).filter_by(id=model.id).update( { **model.model_dump(), "user_id": user_id, @@ -297,17 +308,18 @@ class ModelsTable: "updated_at": int(time.time()), } ) - db.add(new_model) + await db.add(new_model) # Remove models that are no longer present for model in existing_models: if model.id not in new_model_ids: - db.delete(model) + await db.delete(model) - db.commit() + await db.commit() return [ - ModelModel.model_validate(model) for model in db.query(Model).all() + ModelModel.model_validate(model) + for model in await db.query(Model).all() ] except Exception as e: log.exception(f"Error syncing models for user {user_id}: {e}") diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 555581a95d..5769c664d3 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -602,7 +602,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") # Get all models - models = Models.get_all_models() + models = await Models.get_all_models() log.info(f"Found {len(models)} models to check for knowledge base {id}") # Update models that reference this knowledge base @@ -626,7 +626,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): access_control=model.access_control, is_active=model.is_active, ) - Models.update_model_by_id(model.id, model_form) + await Models.update_model_by_id(model.id, model_form) # Clean up vector DB try: diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 7ae7f07301..c11772fe1c 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -28,9 +28,9 @@ router = APIRouter() @router.get("/", response_model=list[ModelUserResponse]) async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: - return Models.get_models() + return await Models.get_models() else: - return Models.get_models_by_user_id(user.id) + return await Models.get_models_by_user_id(user.id) ########################### @@ -40,7 +40,7 @@ async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): @router.get("/base", response_model=list[ModelResponse]) async def get_base_models(user=Depends(get_admin_user)): - return Models.get_base_models() + return await Models.get_base_models() ############################ @@ -62,7 +62,7 @@ async def create_new_model( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - model = Models.get_model_by_id(form_data.id) + model = await Models.get_model_by_id(form_data.id) if model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -70,7 +70,7 @@ async def create_new_model( ) else: - model = Models.insert_new_model(form_data, user.id) + model = await Models.insert_new_model(form_data, user.id) if model: return model else: @@ -87,7 +87,7 @@ async def create_new_model( @router.get("/export", response_model=list[ModelModel]) async def export_models(user=Depends(get_admin_user)): - return Models.get_models() + return await Models.get_models() ############################ @@ -103,7 +103,7 @@ class SyncModelsForm(BaseModel): async def sync_models( request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) ): - return Models.sync_models(user.id, form_data.models) + return await Models.sync_models(user.id, form_data.models) ########################### @@ -114,7 +114,7 @@ async def sync_models( # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id @router.get("/model", response_model=Optional[ModelResponse]) async def get_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) + model = await Models.get_model_by_id(id) if model: if ( (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) @@ -136,14 +136,14 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): @router.post("/model/toggle", response_model=Optional[ModelResponse]) async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) + model = await Models.get_model_by_id(id) if model: if ( user.role == "admin" or model.user_id == user.id or await has_access(user.id, "write", model.access_control) ): - model = Models.toggle_model_by_id(id) + model = await Models.toggle_model_by_id(id) if model: return model @@ -175,7 +175,7 @@ async def update_model_by_id( form_data: ModelForm, user=Depends(get_verified_user), ): - model = Models.get_model_by_id(id) + model = await Models.get_model_by_id(id) if not model: raise HTTPException( @@ -193,7 +193,7 @@ async def update_model_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - model = Models.update_model_by_id(id, form_data) + model = await Models.update_model_by_id(id, form_data) return model diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 7c0bdc3f4c..56096579dd 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -441,7 +441,7 @@ async def get_filtered_models(models, user): # Filter models based on user access control filtered_models = [] for model in models.get("models", []): - model_info = Models.get_model_by_id(model["model"]) + model_info = await Models.get_model_by_id(model["model"]) if model_info: if user.id == model_info.user_id or ( await has_access( @@ -1320,7 +1320,7 @@ async def generate_chat_completion( del payload["metadata"] model_id = payload["model"] - model_info = Models.get_model_by_id(model_id) + model_info = await Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index c643591669..5f867c6833 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -383,7 +383,7 @@ async def get_filtered_models(models, user): # Filter models based on user access control filtered_models = [] for model in models.get("data", []): - model_info = Models.get_model_by_id(model["id"]) + model_info = await Models.get_model_by_id(model["id"]) if model_info: if user.id == model_info.user_id or ( await has_access( @@ -738,7 +738,7 @@ async def generate_chat_completion( metadata = payload.pop("metadata", None) model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) + model_info = await Models.get_model_by_id(model_id) # Check model info and override the payload if model_info: diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 27bf5d2f2d..ae637615c9 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -145,7 +145,7 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) for function in Functions.get_functions_by_type("filter", active_only=True) ] - custom_models = Models.get_all_models() + custom_models = await Models.get_all_models() for custom_model in custom_models: if custom_model.base_model_id is None: # Applied directly to a base model