refac: models workspace optimization

This commit is contained in:
Timothy Jaeryang Baek 2025-11-22 23:20:51 -05:00
parent f9c96d03ad
commit b2034861ae
5 changed files with 292 additions and 70 deletions

View file

@ -6,12 +6,12 @@ from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import User, UserModel, Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import or_, and_, func from sqlalchemy import String, cast, or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
@ -133,6 +133,11 @@ class ModelResponse(ModelModel):
pass pass
class ModelListResponse(BaseModel):
items: list[ModelUserResponse]
total: int
class ModelForm(BaseModel): class ModelForm(BaseModel):
id: str id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
@ -215,6 +220,84 @@ class ModelsTable:
or has_access(user_id, permission, model.access_control, user_group_ids) or has_access(user_id, permission, model.access_control, user_group_ids)
] ]
def search_models(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> ModelListResponse:
with get_db() as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
query = query.filter(Model.base_model_id != None)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Model.name.ilike(f"%{query_key}%"),
Model.base_model_id.ilike(f"%{query_key}%"),
)
)
if filter.get("user_id"):
query = query.filter(Model.user_id == filter.get("user_id"))
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Model.user_id == user_id)
elif view_option == "shared":
query = query.filter(Model.user_id != user_id)
tag = filter.get("tag")
if tag:
# TODO: This is a simple implementation and should be improved for performance
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
meta_text = func.lower(cast(Model.meta, String))
query = query.filter(meta_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Model.name.asc())
else:
query = query.order_by(Model.name.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Model.created_at.asc())
else:
query = query.order_by(Model.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Model.updated_at.asc())
else:
query = query.order_by(Model.updated_at.desc())
else:
query = query.order_by(Model.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
models = []
for model, user in items:
model_model = ModelModel.model_validate(model)
user_model = UserResponse(**UserModel.model_validate(user).model_dump())
models.append(
ModelUserResponse(**model_model.model_dump(), user=user_model)
)
return ModelListResponse(items=models, total=total)
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db() as db:

View file

@ -9,7 +9,7 @@ from open_webui.models.models import (
ModelForm, ModelForm,
ModelModel, ModelModel,
ModelResponse, ModelResponse,
ModelUserResponse, ModelListResponse,
Models, Models,
) )
@ -44,14 +44,43 @@ def is_valid_model_id(model_id: str) -> bool:
########################### ###########################
PAGE_ITEM_COUNT = 30
@router.get( @router.get(
"/list", response_model=list[ModelUserResponse] "/list", response_model=ModelListResponse
) # do NOT use "/" as path, conflicts with main.py ) # do NOT use "/" as path, conflicts with main.py
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): async def get_models(
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: query: Optional[str] = None,
return Models.get_models() view_option: Optional[str] = None,
else: tag: Optional[str] = None,
return Models.get_models_by_user_id(user.id) order_by: Optional[str] = None,
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
):
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if view_option:
filter["view_option"] = view_option
if tag:
filter["tag"] = tag
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
filter["user_id"] = user.id
return Models.search_models(user.id, filter=filter, skip=skip, limit=limit)
########################### ###########################
@ -64,6 +93,29 @@ async def get_base_models(user=Depends(get_admin_user)):
return Models.get_base_models() return Models.get_base_models()
###########################
# GetModelTags
###########################
@router.get("/tags", response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user)):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
models = Models.get_models()
else:
models = Models.get_models_by_user_id(user.id)
tags_set = set()
for model in models:
if model.meta and model.meta.tags:
for tag in model.meta.tags:
tags_set.add((tag.get("name")))
tags = [tag for tag in tags_set]
tags.sort()
return tags
############################ ############################
# CreateNewModel # CreateNewModel
############################ ############################

View file

@ -1,9 +1,68 @@
import { WEBUI_API_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL } from '$lib/constants';
export const getModelItems = async (token: string = '') => { export const getModelItems = async (
token: string = '',
query,
viewOption,
selectedTag,
orderBy,
direction,
page
) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/models/list`, { const searchParams = new URLSearchParams();
if (query) {
searchParams.append('query', query);
}
if (viewOption) {
searchParams.append('view_option', viewOption);
}
if (selectedTag) {
searchParams.append('tag', selectedTag);
}
if (orderBy) {
searchParams.append('order_by', orderBy);
}
if (direction) {
searchParams.append('direction', direction);
}
if (page) {
searchParams.append('page', page.toString());
}
const res = await fetch(`${WEBUI_API_BASE_URL}/models/list?${searchParams.toString()}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.error(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getModelTags = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/models/tags`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',

View file

@ -17,6 +17,7 @@
createNewModel, createNewModel,
deleteModelById, deleteModelById,
getModelItems as getWorkspaceModels, getModelItems as getWorkspaceModels,
getModelTags,
toggleModelById, toggleModelById,
updateModelById updateModelById
} from '$lib/apis/models'; } from '$lib/apis/models';
@ -41,6 +42,7 @@
import Eye from '../icons/Eye.svelte'; import Eye from '../icons/Eye.svelte';
import ViewSelector from './common/ViewSelector.svelte'; import ViewSelector from './common/ViewSelector.svelte';
import TagSelector from './common/TagSelector.svelte'; import TagSelector from './common/TagSelector.svelte';
import Pagination from '../common/Pagination.svelte';
let shiftKey = false; let shiftKey = false;
@ -50,41 +52,61 @@
let loaded = false; let loaded = false;
let models = [];
let tags = [];
let viewOption = '';
let selectedTag = '';
let filteredModels = [];
let selectedModel = null;
let showModelDeleteConfirm = false; let showModelDeleteConfirm = false;
let group_ids = []; let selectedModel = null;
$: if (models && query !== undefined && selectedTag !== undefined && viewOption !== undefined) { let groupIds = [];
setFilteredModels();
}
const setFilteredModels = async () => { let tags = [];
filteredModels = models.filter((m) => { let selectedTag = '';
if (query === '' && selectedTag === '' && viewOption === '') return true;
const lowerQuery = query.toLowerCase();
return (
((m.name || '').toLowerCase().includes(lowerQuery) ||
(m.user?.name || '').toLowerCase().includes(lowerQuery) || // Search by user name
(m.user?.email || '').toLowerCase().includes(lowerQuery)) && // Search by user email
(selectedTag === '' ||
m?.meta?.tags?.some((tag) => tag.name.toLowerCase() === selectedTag.toLowerCase())) &&
(viewOption === '' ||
(viewOption === 'created' && m.user_id === $user?.id) ||
(viewOption === 'shared' && m.user_id !== $user?.id))
);
});
};
let query = ''; let query = '';
let viewOption = '';
let page = 1;
let models = null;
let total = null;
$: if (
page !== undefined &&
query !== undefined &&
selectedTag !== undefined &&
viewOption !== undefined
) {
getModelList();
}
const getModelList = async () => {
try {
const res = await getWorkspaceModels(
localStorage.token,
query,
viewOption,
selectedTag,
null,
null,
page
).catch((error) => {
toast.error(`${error}`);
return null;
});
if (res) {
models = res.items;
total = res.total;
// get tags
tags = await getModelTags(localStorage.token).catch((error) => {
toast.error(`${error}`);
return [];
});
}
} catch (err) {
console.error(err);
}
};
const deleteModelHandler = async (model) => { const deleteModelHandler = async (model) => {
const res = await deleteModelById(localStorage.token, model.id).catch((e) => { const res = await deleteModelById(localStorage.token, model.id).catch((e) => {
toast.error(`${e}`); toast.error(`${e}`);
@ -93,6 +115,9 @@
if (res) { if (res) {
toast.success($i18n.t(`Deleted {{name}}`, { name: model.id })); toast.success($i18n.t(`Deleted {{name}}`, { name: model.id }));
page = 1;
getModelList();
} }
await _models.set( await _models.set(
@ -101,7 +126,6 @@
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
) )
); );
models = await getWorkspaceModels(localStorage.token);
}; };
const cloneModelHandler = async (model) => { const cloneModelHandler = async (model) => {
@ -148,6 +172,9 @@
status: model.meta.hidden ? 'hidden' : 'visible' status: model.meta.hidden ? 'hidden' : 'visible'
}) })
); );
page = 1;
getModelList();
} }
await _models.set( await _models.set(
@ -156,7 +183,6 @@
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
) )
); );
models = await getWorkspaceModels(localStorage.token);
}; };
const copyLinkHandler = async (model) => { const copyLinkHandler = async (model) => {
@ -184,26 +210,15 @@
saveAs(blob, `${model.id}-${Date.now()}.json`); saveAs(blob, `${model.id}-${Date.now()}.json`);
}; };
const setTags = () => {
if (models) {
tags = models
.filter((model) => !(model?.meta?.hidden ?? false))
.flatMap((model) => model?.meta?.tags ?? [])
.map((tag) => tag.name);
// Remove duplicates and sort
tags = Array.from(new Set(tags)).sort((a, b) => a.localeCompare(b));
}
};
onMount(async () => { onMount(async () => {
viewOption = localStorage.workspaceViewOption ?? ''; viewOption = localStorage.workspaceViewOption ?? '';
page = 1;
await getModelList();
models = await getWorkspaceModels(localStorage.token);
let groups = await getGroups(localStorage.token); let groups = await getGroups(localStorage.token);
group_ids = groups.map((group) => group.id); groupIds = groups.map((group) => group.id);
setTags();
loaded = true; loaded = true;
const onKeyDown = (event) => { const onKeyDown = (event) => {
@ -261,8 +276,9 @@
let reader = new FileReader(); let reader = new FileReader();
reader.onload = async (event) => { reader.onload = async (event) => {
let savedModels = [];
try { try {
let savedModels = JSON.parse(event.target.result); savedModels = JSON.parse(event.target.result);
console.log(savedModels); console.log(savedModels);
} catch (e) { } catch (e) {
toast.error($i18n.t('Invalid JSON file')); toast.error($i18n.t('Invalid JSON file'));
@ -273,16 +289,19 @@
if (model?.info ?? false) { if (model?.info ?? false) {
if ($_models.find((m) => m.id === model.id)) { if ($_models.find((m) => m.id === model.id)) {
await updateModelById(localStorage.token, model.id, model.info).catch((error) => { await updateModelById(localStorage.token, model.id, model.info).catch((error) => {
toast.error(`${error}`);
return null; return null;
}); });
} else { } else {
await createNewModel(localStorage.token, model.info).catch((error) => { await createNewModel(localStorage.token, model.info).catch((error) => {
toast.error(`${error}`);
return null; return null;
}); });
} }
} else { } else {
if (model?.id && model?.name) { if (model?.id && model?.name) {
await createNewModel(localStorage.token, model).catch((error) => { await createNewModel(localStorage.token, model).catch((error) => {
toast.error(`${error}`);
return null; return null;
}); });
} }
@ -295,7 +314,9 @@
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
) )
); );
models = await getWorkspaceModels(localStorage.token);
page = 1;
getModelList();
}; };
reader.readAsText(importFiles[0]); reader.readAsText(importFiles[0]);
@ -308,7 +329,7 @@
</div> </div>
<div class="text-lg font-medium text-gray-500 dark:text-gray-500"> <div class="text-lg font-medium text-gray-500 dark:text-gray-500">
{filteredModels.length} {total}
</div> </div>
</div> </div>
@ -326,7 +347,7 @@
</button> </button>
{/if} {/if}
{#if models.length && ($user?.role === 'admin' || $user?.permissions?.workspace?.models_export)} {#if total && ($user?.role === 'admin' || $user?.permissions?.workspace?.models_export)}
<button <button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-200 transition" class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-200 transition"
on:click={async () => { on:click={async () => {
@ -396,9 +417,7 @@
bind:value={viewOption} bind:value={viewOption}
onChange={async (value) => { onChange={async (value) => {
localStorage.workspaceViewOption = value; localStorage.workspaceViewOption = value;
await tick(); await tick();
setTags();
}} }}
/> />
@ -413,9 +432,9 @@
</div> </div>
</div> </div>
{#if (filteredModels ?? []).length !== 0} {#if (models ?? []).length !== 0}
<div class=" px-3 my-2 gap-1 lg:gap-2 grid lg:grid-cols-2" id="model-list"> <div class=" px-3 my-2 gap-1 lg:gap-2 grid lg:grid-cols-2" id="model-list">
{#each filteredModels as model (model.id)} {#each models as model (model.id)}
<!-- svelte-ignore a11y_no_static_element_interactions --> <!-- svelte-ignore a11y_no_static_element_interactions -->
<!-- svelte-ignore a11y_click_events_have_key_events --> <!-- svelte-ignore a11y_click_events_have_key_events -->
<div <div
@ -425,7 +444,7 @@
if ( if (
$user?.role === 'admin' || $user?.role === 'admin' ||
model.user_id === $user?.id || model.user_id === $user?.id ||
model.access_control.write.group_ids.some((wg) => group_ids.includes(wg)) model.access_control.write.group_ids.some((wg) => groupIds.includes(wg))
) { ) {
goto(`/workspace/models/edit?id=${encodeURIComponent(model.id)}`); goto(`/workspace/models/edit?id=${encodeURIComponent(model.id)}`);
} }
@ -607,6 +626,10 @@
</div> </div>
{/each} {/each}
</div> </div>
{#if total > 30}
<Pagination bind:page count={total} perPage={30} />
{/if}
{:else} {:else}
<div class=" w-full h-full flex flex-col justify-center items-center my-16 mb-24"> <div class=" w-full h-full flex flex-col justify-center items-center my-16 mb-24">
<div class="max-w-md text-center"> <div class="max-w-md text-center">

View file

@ -27,10 +27,15 @@
class="relative w-full flex items-center gap-0.5 px-2.5 py-1.5 rounded-xl " class="relative w-full flex items-center gap-0.5 px-2.5 py-1.5 rounded-xl "
aria-label={placeholder} aria-label={placeholder}
> >
<Select.Value <div
class="inline-flex h-input px-0.5 w-full outline-hidden bg-transparent truncate placeholder-gray-400 focus:outline-hidden capitalize" class="inline-flex h-input px-0.5 w-full outline-hidden bg-transparent truncate placeholder-gray-400 focus:outline-hidden capitalize"
>
{#if value}
{value}
{:else}
{placeholder} {placeholder}
/> {/if}
</div>
{#if value} {#if value}
<button <button