mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: model sync endpoint
This commit is contained in:
parent
49a6211d36
commit
c1e4139e5c
2 changed files with 62 additions and 0 deletions
|
|
@ -269,5 +269,49 @@ class ModelsTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
# Get existing models
|
||||||
|
existing_models = db.query(Model).all()
|
||||||
|
existing_ids = {model.id for model in existing_models}
|
||||||
|
|
||||||
|
# Prepare a set of new model IDs
|
||||||
|
new_model_ids = {model.id for model in models}
|
||||||
|
|
||||||
|
# Update or insert models
|
||||||
|
for model in models:
|
||||||
|
if model.id in existing_ids:
|
||||||
|
db.query(Model).filter_by(id=model.id).update(
|
||||||
|
{
|
||||||
|
**model.model_dump(),
|
||||||
|
"user_id": user_id,
|
||||||
|
"updated_at": int(time.time()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_model = Model(
|
||||||
|
**{
|
||||||
|
**model.model_dump(),
|
||||||
|
"user_id": user_id,
|
||||||
|
"updated_at": int(time.time()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db.add(new_model)
|
||||||
|
|
||||||
|
# Remove models that are no longer present
|
||||||
|
for model in existing_models:
|
||||||
|
if model.id not in new_model_ids:
|
||||||
|
db.delete(model)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return [
|
||||||
|
ModelModel.model_validate(model) for model in db.query(Model).all()
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Error syncing models for user {user_id}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
Models = ModelsTable()
|
Models = ModelsTable()
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ from open_webui.models.models import (
|
||||||
ModelUserResponse,
|
ModelUserResponse,
|
||||||
Models,
|
Models,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
|
|
@ -78,6 +80,22 @@ async def create_new_model(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# SyncModels
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
class SyncModelsForm(BaseModel):
|
||||||
|
models: list[ModelModel] = []
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sync", response_model=list[ModelModel])
|
||||||
|
async def sync_models(
|
||||||
|
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
|
||||||
|
):
|
||||||
|
return Models.sync_models(user.id, form_data.models)
|
||||||
|
|
||||||
|
|
||||||
###########################
|
###########################
|
||||||
# GetModelById
|
# GetModelById
|
||||||
###########################
|
###########################
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue