mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
feat: model sync endpoint
This commit is contained in:
parent
49a6211d36
commit
c1e4139e5c
2 changed files with 62 additions and 0 deletions
|
|
@ -269,5 +269,49 @@ class ModelsTable:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Get existing models
|
||||
existing_models = db.query(Model).all()
|
||||
existing_ids = {model.id for model in existing_models}
|
||||
|
||||
# Prepare a set of new model IDs
|
||||
new_model_ids = {model.id for model in models}
|
||||
|
||||
# Update or insert models
|
||||
for model in models:
|
||||
if model.id in existing_ids:
|
||||
db.query(Model).filter_by(id=model.id).update(
|
||||
{
|
||||
**model.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_model = Model(
|
||||
**{
|
||||
**model.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(new_model)
|
||||
|
||||
# Remove models that are no longer present
|
||||
for model in existing_models:
|
||||
if model.id not in new_model_ids:
|
||||
db.delete(model)
|
||||
|
||||
db.commit()
|
||||
|
||||
return [
|
||||
ModelModel.model_validate(model) for model in db.query(Model).all()
|
||||
]
|
||||
except Exception as e:
|
||||
log.exception(f"Error syncing models for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
Models = ModelsTable()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from open_webui.models.models import (
|
|||
ModelUserResponse,
|
||||
Models,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
|
|
@ -78,6 +80,22 @@ async def create_new_model(
|
|||
)
|
||||
|
||||
|
||||
############################
|
||||
# SyncModels
|
||||
############################
|
||||
|
||||
|
||||
class SyncModelsForm(BaseModel):
|
||||
models: list[ModelModel] = []
|
||||
|
||||
|
||||
@router.post("/sync", response_model=list[ModelModel])
|
||||
async def sync_models(
|
||||
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
return Models.sync_models(user.id, form_data.models)
|
||||
|
||||
|
||||
###########################
|
||||
# GetModelById
|
||||
###########################
|
||||
|
|
|
|||
Loading…
Reference in a new issue