diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 05d7c68006..2ced14de4e 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,6 +1,9 @@ from typing import Optional import io import base64 +import json +import asyncio +import logging from open_webui.models.models import ( ModelForm, @@ -12,7 +15,16 @@ from open_webui.models.models import ( from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status, Response +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + status, + Response, + UploadFile, + File, +) from fastapi.responses import FileResponse, StreamingResponse @@ -20,6 +32,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR +log = logging.getLogger(__name__) + router = APIRouter() @@ -93,6 +107,46 @@ async def export_models(user=Depends(get_admin_user)): return Models.get_models() +############################ +# ImportModels +############################ + + +@router.post("/import", response_model=bool) +async def import_models( + user: str = Depends(get_admin_user), file: UploadFile = File(...) +): + try: + data = json.loads(await file.read()) + if isinstance(data, list): + for model_data in data: + # Here, you can add logic to validate model_data if needed + model_id = model_data.get("id") + if model_id: + existing_model = Models.get_model_by_id(model_id) + if existing_model: + # Update existing model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + + updated_model = ModelForm( + **{**existing_model.model_dump(), **model_data} + ) + Models.update_model_by_id(model_id, updated_model) + else: + # Insert new model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + new_model = ModelForm(**model_data) + Models.insert_new_model(user_id=user.id, form_data=new_model) + return True + else: + raise HTTPException(status_code=400, detail="Invalid JSON format") + except Exception as e: + log.exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + ############################ # SyncModels ############################ diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 3e6e0d0c0b..63ca95f200 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -31,6 +31,36 @@ export const getModels = async (token: string = '') => { return res; }; +export const importModels = async (token: string, file: File) => { + let error = null; + + const formData = new FormData(); + formData.append('file', file); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, { + method: 'POST', + headers: { + authorization: `Bearer ${token}` + }, + body: formData + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getBaseModels = async (token: string = '') => { let error = null; diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index f3df30377f..211d60efa5 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -12,7 +12,8 @@ deleteAllModels, getBaseModels, toggleModelById, - updateModelById + updateModelById, + importModels } from '$lib/apis/models'; import { copyToClipboard } from '$lib/utils'; import { page } from '$app/stores'; @@ -40,6 +41,7 @@ let shiftKey = false; +let modelsImportInProgress = false; let importFiles; let modelsImportInputElement: HTMLInputElement; @@ -463,48 +465,32 @@ type="file" accept=".json" hidden - on:change={() => { - console.log(importFiles); + on:change={async () => { + if (importFiles.length > 0) { + modelsImportInProgress = true; + const res = await importModels(localStorage.token, importFiles[0]); + modelsImportInProgress = false; - let reader = new FileReader(); - reader.onload = async (event) => { - let savedModels = JSON.parse(event.target.result); - console.log(savedModels); - - for (const model of savedModels) { - if (Object.keys(model).includes('base_model_id')) { - if (model.base_model_id === null) { - upsertModelHandler(model); - } - } else { - if (model?.info ?? false) { - if (model.info.base_model_id === null) { - upsertModelHandler(model.info); - } - } - } + if (res) { + toast.success($i18n.t('Models imported successfully')); + await init(); + } else { + toast.error($i18n.t('Failed to import models')); } - - await _models.set( - await getModels( - localStorage.token, - $config?.features?.enable_direct_connections && - ($settings?.directConnections ?? null) - ) - ); - init(); - }; - - reader.readAsText(importFiles[0]); + } }} />