From b2034861aec51680c809b56920bb0029df872d06 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 22 Nov 2025 23:20:51 -0500 Subject: [PATCH] refac: models workspace optimization --- backend/open_webui/models/models.py | 87 +++++++++++- backend/open_webui/routers/models.py | 66 ++++++++- src/lib/apis/models/index.ts | 63 ++++++++- src/lib/components/workspace/Models.svelte | 133 ++++++++++-------- .../workspace/common/TagSelector.svelte | 13 +- 5 files changed, 292 insertions(+), 70 deletions(-) diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index f5964c0579..f9390b405d 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -6,12 +6,12 @@ from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS from open_webui.models.groups import Groups -from open_webui.models.users import Users, UserResponse +from open_webui.models.users import User, UserModel, Users, UserResponse from pydantic import BaseModel, ConfigDict -from sqlalchemy import or_, and_, func +from sqlalchemy import String, cast, or_, and_, func from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy import BigInteger, Column, Text, JSON, Boolean @@ -133,6 +133,11 @@ class ModelResponse(ModelModel): pass +class ModelListResponse(BaseModel): + items: list[ModelUserResponse] + total: int + + class ModelForm(BaseModel): id: str base_model_id: Optional[str] = None @@ -215,6 +220,84 @@ class ModelsTable: or has_access(user_id, permission, model.access_control, user_group_ids) ] + def search_models( + self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30 + ) -> ModelListResponse: + with get_db() as db: + # Join GroupMember so we can order by group_id when requested + query = db.query(Model, User).outerjoin(User, User.id == Model.user_id) + query = query.filter(Model.base_model_id != None) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter( + or_( + Model.name.ilike(f"%{query_key}%"), + Model.base_model_id.ilike(f"%{query_key}%"), + ) + ) + + if filter.get("user_id"): + query = query.filter(Model.user_id == filter.get("user_id")) + + view_option = filter.get("view_option") + + if view_option == "created": + query = query.filter(Model.user_id == user_id) + elif view_option == "shared": + query = query.filter(Model.user_id != user_id) + + tag = filter.get("tag") + if tag: + # TODO: This is a simple implementation and should be improved for performance + like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array + meta_text = func.lower(cast(Model.meta, String)) + + query = query.filter(meta_text.like(like_pattern)) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by == "name": + if direction == "asc": + query = query.order_by(Model.name.asc()) + else: + query = query.order_by(Model.name.desc()) + elif order_by == "created_at": + if direction == "asc": + query = query.order_by(Model.created_at.asc()) + else: + query = query.order_by(Model.created_at.desc()) + elif order_by == "updated_at": + if direction == "asc": + query = query.order_by(Model.updated_at.asc()) + else: + query = query.order_by(Model.updated_at.desc()) + + else: + query = query.order_by(Model.created_at.desc()) + + # Count BEFORE pagination + total = query.count() + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + items = query.all() + + models = [] + for model, user in items: + model_model = ModelModel.model_validate(model) + user_model = UserResponse(**UserModel.model_validate(user).model_dump()) + models.append( + ModelUserResponse(**model_model.model_dump(), user=user_model) + ) + + return ModelListResponse(items=models, total=total) + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: with get_db() as db: diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index a689d26e98..d997744f58 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -9,7 +9,7 @@ from open_webui.models.models import ( ModelForm, ModelModel, ModelResponse, - ModelUserResponse, + ModelListResponse, Models, ) @@ -44,14 +44,43 @@ def is_valid_model_id(model_id: str) -> bool: ########################### +PAGE_ITEM_COUNT = 30 + + @router.get( - "/list", response_model=list[ModelUserResponse] + "/list", response_model=ModelListResponse ) # do NOT use "/" as path, conflicts with main.py -async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): - if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - return Models.get_models() - else: - return Models.get_models_by_user_id(user.id) +async def get_models( + query: Optional[str] = None, + view_option: Optional[str] = None, + tag: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, + page: Optional[int] = 1, + user=Depends(get_verified_user), +): + + limit = PAGE_ITEM_COUNT + + page = max(1, page) + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if view_option: + filter["view_option"] = view_option + if tag: + filter["tag"] = tag + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: + filter["user_id"] = user.id + + return Models.search_models(user.id, filter=filter, skip=skip, limit=limit) ########################### @@ -64,6 +93,29 @@ async def get_base_models(user=Depends(get_admin_user)): return Models.get_base_models() +########################### +# GetModelTags +########################### + + +@router.get("/tags", response_model=list[str]) +async def get_model_tags(user=Depends(get_verified_user)): + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + models = Models.get_models() + else: + models = Models.get_models_by_user_id(user.id) + + tags_set = set() + for model in models: + if model.meta and model.meta.tags: + for tag in model.meta.tags: + tags_set.add((tag.get("name"))) + + tags = [tag for tag in tags_set] + tags.sort() + return tags + + ############################ # CreateNewModel ############################ diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index a9385088f7..d03a83e9ca 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,9 +1,68 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const getModelItems = async (token: string = '') => { +export const getModelItems = async ( + token: string = '', + query, + viewOption, + selectedTag, + orderBy, + direction, + page +) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models/list`, { + const searchParams = new URLSearchParams(); + if (query) { + searchParams.append('query', query); + } + if (viewOption) { + searchParams.append('view_option', viewOption); + } + if (selectedTag) { + searchParams.append('tag', selectedTag); + } + if (orderBy) { + searchParams.append('order_by', orderBy); + } + if (direction) { + searchParams.append('direction', direction); + } + if (page) { + searchParams.append('page', page.toString()); + } + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/list?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getModelTags = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/tags`, { method: 'GET', headers: { Accept: 'application/json', diff --git a/src/lib/components/workspace/Models.svelte b/src/lib/components/workspace/Models.svelte index 0343b7a5bb..72f1417924 100644 --- a/src/lib/components/workspace/Models.svelte +++ b/src/lib/components/workspace/Models.svelte @@ -17,6 +17,7 @@ createNewModel, deleteModelById, getModelItems as getWorkspaceModels, + getModelTags, toggleModelById, updateModelById } from '$lib/apis/models'; @@ -41,6 +42,7 @@ import Eye from '../icons/Eye.svelte'; import ViewSelector from './common/ViewSelector.svelte'; import TagSelector from './common/TagSelector.svelte'; + import Pagination from '../common/Pagination.svelte'; let shiftKey = false; @@ -50,41 +52,61 @@ let loaded = false; - let models = []; - let tags = []; - - let viewOption = ''; - let selectedTag = ''; - - let filteredModels = []; - let selectedModel = null; - let showModelDeleteConfirm = false; - let group_ids = []; + let selectedModel = null; - $: if (models && query !== undefined && selectedTag !== undefined && viewOption !== undefined) { - setFilteredModels(); - } + let groupIds = []; - const setFilteredModels = async () => { - filteredModels = models.filter((m) => { - if (query === '' && selectedTag === '' && viewOption === '') return true; - const lowerQuery = query.toLowerCase(); - return ( - ((m.name || '').toLowerCase().includes(lowerQuery) || - (m.user?.name || '').toLowerCase().includes(lowerQuery) || // Search by user name - (m.user?.email || '').toLowerCase().includes(lowerQuery)) && // Search by user email - (selectedTag === '' || - m?.meta?.tags?.some((tag) => tag.name.toLowerCase() === selectedTag.toLowerCase())) && - (viewOption === '' || - (viewOption === 'created' && m.user_id === $user?.id) || - (viewOption === 'shared' && m.user_id !== $user?.id)) - ); - }); - }; + let tags = []; + let selectedTag = ''; let query = ''; + let viewOption = ''; + + let page = 1; + let models = null; + let total = null; + + $: if ( + page !== undefined && + query !== undefined && + selectedTag !== undefined && + viewOption !== undefined + ) { + getModelList(); + } + + const getModelList = async () => { + try { + const res = await getWorkspaceModels( + localStorage.token, + query, + viewOption, + selectedTag, + null, + null, + page + ).catch((error) => { + toast.error(`${error}`); + return null; + }); + + if (res) { + models = res.items; + total = res.total; + + // get tags + tags = await getModelTags(localStorage.token).catch((error) => { + toast.error(`${error}`); + return []; + }); + } + } catch (err) { + console.error(err); + } + }; + const deleteModelHandler = async (model) => { const res = await deleteModelById(localStorage.token, model.id).catch((e) => { toast.error(`${e}`); @@ -93,6 +115,9 @@ if (res) { toast.success($i18n.t(`Deleted {{name}}`, { name: model.id })); + + page = 1; + getModelList(); } await _models.set( @@ -101,7 +126,6 @@ $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) ) ); - models = await getWorkspaceModels(localStorage.token); }; const cloneModelHandler = async (model) => { @@ -148,6 +172,9 @@ status: model.meta.hidden ? 'hidden' : 'visible' }) ); + + page = 1; + getModelList(); } await _models.set( @@ -156,7 +183,6 @@ $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) ) ); - models = await getWorkspaceModels(localStorage.token); }; const copyLinkHandler = async (model) => { @@ -184,26 +210,15 @@ saveAs(blob, `${model.id}-${Date.now()}.json`); }; - const setTags = () => { - if (models) { - tags = models - .filter((model) => !(model?.meta?.hidden ?? false)) - .flatMap((model) => model?.meta?.tags ?? []) - .map((tag) => tag.name); - - // Remove duplicates and sort - tags = Array.from(new Set(tags)).sort((a, b) => a.localeCompare(b)); - } - }; - onMount(async () => { viewOption = localStorage.workspaceViewOption ?? ''; + page = 1; + + await getModelList(); - models = await getWorkspaceModels(localStorage.token); let groups = await getGroups(localStorage.token); - group_ids = groups.map((group) => group.id); + groupIds = groups.map((group) => group.id); - setTags(); loaded = true; const onKeyDown = (event) => { @@ -261,8 +276,9 @@ let reader = new FileReader(); reader.onload = async (event) => { + let savedModels = []; try { - let savedModels = JSON.parse(event.target.result); + savedModels = JSON.parse(event.target.result); console.log(savedModels); } catch (e) { toast.error($i18n.t('Invalid JSON file')); @@ -273,16 +289,19 @@ if (model?.info ?? false) { if ($_models.find((m) => m.id === model.id)) { await updateModelById(localStorage.token, model.id, model.info).catch((error) => { + toast.error(`${error}`); return null; }); } else { await createNewModel(localStorage.token, model.info).catch((error) => { + toast.error(`${error}`); return null; }); } } else { if (model?.id && model?.name) { await createNewModel(localStorage.token, model).catch((error) => { + toast.error(`${error}`); return null; }); } @@ -295,7 +314,9 @@ $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) ) ); - models = await getWorkspaceModels(localStorage.token); + + page = 1; + getModelList(); }; reader.readAsText(importFiles[0]); @@ -308,7 +329,7 @@
- {filteredModels.length} + {total}
@@ -326,7 +347,7 @@ {/if} - {#if models.length && ($user?.role === 'admin' || $user?.permissions?.workspace?.models_export)} + {#if total && ($user?.role === 'admin' || $user?.permissions?.workspace?.models_export)}