open-webui/backend/open_webui/routers/models.py

418 lines
12 KiB
Python
Raw Normal View History

2024-08-27 22:10:27 +00:00
from typing import Optional
2025-09-17 05:49:44 +00:00
import io
import base64
import json
import asyncio
import logging
2024-05-24 07:26:00 +00:00
2025-11-25 11:32:27 +00:00
from open_webui.models.groups import Groups
2024-12-10 08:54:13 +00:00
from open_webui.models.models import (
ModelForm,
ModelModel,
ModelResponse,
2025-11-23 04:20:51 +00:00
ModelListResponse,
Models,
)
2025-07-28 09:06:05 +00:00
from pydantic import BaseModel
from open_webui.constants import ERROR_MESSAGES
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
status,
Response,
)
2025-09-17 05:49:44 +00:00
from fastapi.responses import FileResponse, StreamingResponse
2024-11-15 09:29:07 +00:00
2024-12-09 00:01:56 +00:00
from open_webui.utils.auth import get_admin_user, get_verified_user
2024-11-17 11:04:31 +00:00
from open_webui.utils.access_control import has_access, has_permission
2025-09-17 05:49:44 +00:00
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
2024-05-24 07:26:00 +00:00
log = logging.getLogger(__name__)
2024-05-24 07:26:00 +00:00
router = APIRouter()
2024-11-15 09:29:07 +00:00
2025-11-19 08:18:16 +00:00
def is_valid_model_id(model_id: str) -> bool:
    """Return True if *model_id* is a non-empty id of at most 256 characters.

    Always returns a real ``bool``; an empty string or ``None`` yields
    ``False`` rather than being passed back to the caller as-is.
    """
    return bool(model_id) and len(model_id) <= 256
2024-05-24 07:26:00 +00:00
###########################
2024-11-15 09:29:07 +00:00
# GetModels
2024-05-24 07:26:00 +00:00
###########################
2025-11-23 04:20:51 +00:00
PAGE_ITEM_COUNT = 30


@router.get(
    "/list", response_model=ModelListResponse
)  # do NOT use "/" as path, conflicts with main.py
async def get_models(
    query: Optional[str] = None,
    view_option: Optional[str] = None,
    tag: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_verified_user),
):
    """Return one page of models visible to *user*.

    Every query parameter that is set (truthy) is forwarded to
    ``Models.search_models`` under the same key. Unless the caller is an admin
    and ``BYPASS_ADMIN_ACCESS_CONTROL`` is enabled, the search is also scoped
    to the caller's user id and the ids of the groups they belong to.

    Pages are 1-based and hold ``PAGE_ITEM_COUNT`` items.
    """
    limit = PAGE_ITEM_COUNT
    # ``page`` is Optional: treat a missing/None page as page 1 instead of
    # letting max() raise TypeError; non-positive pages are clamped to 1.
    page = max(1, page or 1)
    skip = (page - 1) * limit

    # Named ``filters`` so the builtin ``filter`` is not shadowed.
    filters = {
        key: value
        for key, value in (
            ("query", query),
            ("view_option", view_option),
            ("tag", tag),
            ("order_by", order_by),
            ("direction", direction),
        )
        if value
    }

    if user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
        groups = Groups.get_groups_by_member_id(user.id)
        if groups:
            filters["group_ids"] = [group.id for group in groups]
        filters["user_id"] = user.id

    return Models.search_models(user.id, filter=filters, skip=skip, limit=limit)
2024-05-24 07:26:00 +00:00
2025-03-04 05:27:48 +00:00
###########################
# GetBaseModels
###########################
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)):
    """List the base models; restricted to admins by the ``get_admin_user`` dependency."""
    base_models = Models.get_base_models()
    return base_models
2025-11-23 04:20:51 +00:00
###########################
# GetModelTags
###########################
@router.get("/tags", response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user)):
    """Return the sorted, de-duplicated tag names of the models visible to *user*.

    Admins with ``BYPASS_ADMIN_ACCESS_CONTROL`` see tags from every model;
    everyone else sees tags from ``Models.get_models_by_user_id``.

    Tag entries without a usable string ``name`` are skipped. Otherwise a
    single malformed tag would put ``None`` in the set, and the final sort
    would raise ``TypeError`` because ``None`` and ``str`` are not orderable.
    """
    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
        models = Models.get_models()
    else:
        models = Models.get_models_by_user_id(user.id)

    tag_names = set()
    for model in models:
        if not model.meta:
            continue
        # ``tags`` may be absent or explicitly null in the stored meta.
        for tag in model.meta.model_dump().get("tags") or []:
            name = tag.get("name") if isinstance(tag, dict) else None
            if isinstance(name, str) and name:
                tag_names.add(name)

    return sorted(tag_names)
2024-05-24 07:26:00 +00:00
############################
2024-11-15 09:29:07 +00:00
# CreateNewModel
2024-05-24 07:26:00 +00:00
############################
2024-11-15 09:29:07 +00:00
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
):
    """Create a new model owned by *user* and return it.

    Raises:
        HTTPException 401: the caller is not an admin and lacks the
            ``workspace.models`` permission (UNAUTHORIZED); the id is already
            taken (MODEL_ID_TAKEN); or the insert returned nothing (DEFAULT).
        HTTPException 400: the id fails ``is_valid_model_id`` (empty or longer
            than 256 characters).
    """
    if user.role != "admin" and not has_permission(
        user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    # Validate the id shape first: it is a pure check and saves a DB lookup
    # for ids that could never be inserted anyway.
    if not is_valid_model_id(form_data.id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
        )

    if Models.get_model_by_id(form_data.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )

    model = Models.insert_new_model(form_data, user.id)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return model
2024-05-24 07:26:00 +00:00
2025-07-28 09:12:38 +00:00
############################
# ExportModels
############################
@router.get("/export", response_model=list[ModelModel])
async def export_models(request: Request, user=Depends(get_verified_user)):
    """Return the models the caller is allowed to export.

    Admins always pass the permission check; other users need the
    ``workspace.models_export`` permission. Admins with
    ``BYPASS_ADMIN_ACCESS_CONTROL`` receive every model; everyone else gets
    ``Models.get_models_by_user_id``.
    """
    is_admin = user.role == "admin"
    may_export = is_admin or has_permission(
        user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS
    )
    if not may_export:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    if is_admin and BYPASS_ADMIN_ACCESS_CONTROL:
        return Models.get_models()
    return Models.get_models_by_user_id(user.id)
2025-07-28 09:12:38 +00:00
############################
# ImportModels
############################
class ModelsImportForm(BaseModel):
    """Request body for ``/import``: a list of raw model dicts."""

    models: list[dict]


@router.post("/import", response_model=bool)
async def import_models(
    request: Request,
    user=Depends(get_verified_user),
    form_data: ModelsImportForm = (...),
):
    """Create or update models from a list of model dicts.

    For each entry with a valid string ``id``:

    - If a model with that id exists and the caller may write it (admin,
      owner, or ``write`` access), the entry is merged over the existing
      model and saved.
    - If no model with that id exists, the entry is inserted as a new model
      owned by the caller.

    Entries with a missing or invalid id, or that target a model the caller
    may not write, are skipped and logged.

    Raises:
        HTTPException 401: the caller is not an admin and lacks the
            ``workspace.models_import`` permission.
        HTTPException 400: the payload is not a list.
        HTTPException 500: any other failure while importing.
    """
    if user.role != "admin" and not has_permission(
        user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        data = form_data.models
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail="Invalid JSON format")

        for model_data in data:
            model_id = model_data.get("id")
            # Non-string ids would make len() inside is_valid_model_id raise.
            if not (isinstance(model_id, str) and is_valid_model_id(model_id)):
                log.warning("Skipping imported model with invalid id: %r", model_id)
                continue

            model_data["meta"] = model_data.get("meta", {})
            model_data["params"] = model_data.get("params", {})

            existing_model = Models.get_model_by_id(model_id)
            if existing_model:
                # Apply the same ownership/write-access rule as /model/update.
                # Without it, anyone with import permission could overwrite
                # another user's model just by reusing its id.
                if (
                    user.role != "admin"
                    and existing_model.user_id != user.id
                    and not has_access(
                        user.id, "write", existing_model.access_control
                    )
                ):
                    log.warning(
                        "Skipping import of model %r: user %s lacks write access",
                        model_id,
                        user.id,
                    )
                    continue
                updated_model = ModelForm(
                    **{**existing_model.model_dump(), **model_data}
                )
                Models.update_model_by_id(model_id, updated_model)
            else:
                new_model = ModelForm(**model_data)
                Models.insert_new_model(user_id=user.id, form_data=new_model)
        return True
    except HTTPException:
        # Pass deliberate HTTP errors (e.g. the 400 above) through unchanged
        # instead of re-wrapping them as 500s below.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
2025-07-28 09:06:05 +00:00
############################
# SyncModels
############################
class SyncModelsForm(BaseModel):
    """Request body for ``/sync``; ``models`` defaults to an empty list."""

    models: list[ModelModel] = []


@router.post("/sync", response_model=list[ModelModel])
async def sync_models(
    request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
):
    """Pass the submitted models to ``Models.sync_models`` on behalf of the admin *user*."""
    submitted_models = form_data.models
    synced = Models.sync_models(user.id, submitted_models)
    return synced
2024-11-15 09:29:07 +00:00
###########################
# GetModelById
###########################
2025-11-19 08:18:16 +00:00
class ModelIdForm(BaseModel):
    """Request body that identifies a single model by its id."""

    # The target model's id.
    id: str
2024-11-20 02:45:26 +00:00
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
    """Return the model with *id* if the caller may read it.

    Returns None when no model has this id. Raises a 401 with a NOT_FOUND
    detail when the model exists but the caller is neither a bypassing admin,
    nor its owner, nor granted ``read`` access.
    """
    model = Models.get_model_by_id(id)
    if not model:
        return None

    can_read = (
        (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
        or model.user_id == user.id
        or has_access(user.id, "read", model.access_control)
    )
    if not can_read:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    return model
2025-09-17 05:49:44 +00:00
###########################
# GetModelProfileImage
###########################
@router.get("/model/profile/image")
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
    """Serve the profile image of model *id*.

    - URLs starting with ``http``: 302 redirect to that URL.
    - ``data:image...`` URLs: the base64 payload is decoded and returned
      inline. The declared media type is echoed only if it is one of a few
      raster types; anything else (notably ``image/svg+xml``, which can carry
      script) is labelled ``image/png``.
    - Otherwise (unknown model, no image, or undecodable data): the default
      ``favicon.png`` from ``STATIC_DIR``.
    """
    # Raster types that are safe to echo back; maps media type -> file extension.
    safe_image_types = {
        "image/png": "png",
        "image/jpeg": "jpg",
        "image/gif": "gif",
        "image/webp": "webp",
    }

    model = Models.get_model_by_id(id)
    image_url = model.meta.profile_image_url if model and model.meta else None

    if image_url:
        if image_url.startswith("http"):
            return Response(
                status_code=status.HTTP_302_FOUND,
                headers={"Location": image_url},
            )
        if image_url.startswith("data:image"):
            try:
                header, base64_data = image_url.split(",", 1)
                image_data = base64.b64decode(base64_data)
            except ValueError as e:
                # Covers a missing "," (unpacking) and binascii.Error (a
                # ValueError subclass) from malformed base64; fall back to the
                # default image instead of failing the request.
                log.debug("Could not decode profile image for model %s: %s", id, e)
            else:
                # e.g. "data:image/jpeg;base64" -> "image/jpeg"
                media_type = header[len("data:"):].split(";", 1)[0].lower()
                if media_type not in safe_image_types:
                    media_type = "image/png"
                extension = safe_image_types[media_type]
                return Response(
                    content=image_data,
                    media_type=media_type,
                    headers={
                        "Content-Disposition": f"inline; filename=image.{extension}"
                    },
                )

    return FileResponse(f"{STATIC_DIR}/favicon.png")
2024-11-16 02:21:41 +00:00
############################
2025-07-28 09:12:38 +00:00
# ToggleModelById
2024-11-16 02:21:41 +00:00
############################
2024-11-20 02:45:26 +00:00
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
    """Toggle model *id* via ``Models.toggle_model_by_id`` and return the result.

    (Presumably this flips the model's active flag; see
    ``Models.toggle_model_by_id`` to confirm.) Allowed for admins, the model's
    owner, or users with ``write`` access.

    Raises:
        HTTPException 401: no model has this id (NOT_FOUND), or the caller is
            not allowed to change it (UNAUTHORIZED).
        HTTPException 400: the toggle call returned nothing.
    """
    model = Models.get_model_by_id(id)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if (
        user.role != "admin"
        and model.user_id != user.id
        and not has_access(user.id, "write", model.access_control)
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    toggled_model = Models.toggle_model_by_id(id)
    if not toggled_model:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating model"),
        )
    return toggled_model
2024-05-24 07:26:00 +00:00
############################
# UpdateModelById
############################
2024-11-20 02:45:26 +00:00
@router.post("/model/update", response_model=Optional[ModelModel])
async def update_model_by_id(
    form_data: ModelForm,
    user=Depends(get_verified_user),
):
    """Overwrite the model identified by ``form_data.id`` with *form_data*.

    Allowed for admins, the model's owner, or users with ``write`` access.
    Returns whatever ``Models.update_model_by_id`` returns.

    Raises:
        HTTPException 401: no model has this id (NOT_FOUND).
        HTTPException 400: the caller may not modify it (ACCESS_PROHIBITED).
    """
    model = Models.get_model_by_id(form_data.id)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Cheap attribute checks first, so has_access is consulted only when the
    # caller is neither an admin nor the owner.
    if (
        user.role != "admin"
        and model.user_id != user.id
        and not has_access(user.id, "write", model.access_control)
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # form_data is already a validated ModelForm; a model_dump() /
    # ModelForm(**...) round-trip would only re-validate the same data.
    return Models.update_model_by_id(form_data.id, form_data)
2024-05-24 07:26:00 +00:00
############################
# DeleteModelById
############################
2025-11-19 08:18:16 +00:00
@router.post("/model/delete", response_model=bool)
async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)):
    """Delete the model named by ``form_data.id``.

    Allowed for admins, the model's owner, or users with ``write`` access.
    Returns the result of ``Models.delete_model_by_id``.

    Raises:
        HTTPException 401: no model has this id (NOT_FOUND), or the caller may
            not delete it (UNAUTHORIZED).
    """
    target_id = form_data.id
    target = Models.get_model_by_id(target_id)
    if not target:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    is_admin_or_owner = user.role == "admin" or target.user_id == user.id
    if not (
        is_admin_or_owner or has_access(user.id, "write", target.access_control)
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    return Models.delete_model_by_id(target_id)
2024-11-19 19:03:36 +00:00
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)):
    """Delete every model (admin only) and return ``Models.delete_all_models()``'s result."""
    return Models.delete_all_models()