diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 42a30dee1e..d9d6ee8842 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -91,6 +91,7 @@ from open_webui.routers import ( evaluations, tools, users, + user_models, utils, scim, ) @@ -1297,7 +1298,9 @@ app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) - +app.include_router( + user_models.router, prefix="/api/v1/user/models", tags=["user_models"] +) app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"]) app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) app.include_router(notes.router, prefix="/api/v1/notes", tags=["notes"]) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 31f617126a..fbdbaaa0a8 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -38,6 +38,7 @@ from open_webui.routers.pipelines import ( from open_webui.models.functions import Functions from open_webui.models.models import Models +from open_webui.models.user_model_credentials import UserModelCredentials from open_webui.utils.plugin import ( @@ -225,6 +226,23 @@ async def generate_chat_completion( # === 4. 验证模型存在性 === model_id = form_data["model"] + + # 私有模型直连:如果是用户私有模型,使用 credential_id 注入 direct 配置 + if form_data.get("is_user_model") and form_data.get("model_item", {}).get("credential_id"): + cred = UserModelCredentials.get_credential_by_id_and_user_id( + form_data["model_item"]["credential_id"], user.id + ) + if not cred: + raise Exception("User model credential not found") + request.state.direct = True + request.state.model = { + "id": cred.model_id, + "name": cred.name or cred.model_id, + "base_url": cred.base_url, + "api_key": cred.api_key, + } + models = {request.state.model["id"]: request.state.model} + model_id = cred.model_id if model_id not in models: raise Exception("Model not found") diff --git a/src/lib/apis/userModels.ts b/src/lib/apis/userModels.ts new file mode 100644 index 0000000000..df6e8965d7 --- /dev/null +++ b/src/lib/apis/userModels.ts @@ -0,0 +1,130 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export type UserModelCredentialForm = { + name?: string; + model_id: string; + base_url?: string; + api_key: string; + config?: object; +}; + +// WEBUI_API_BASE_URL 已包含 /api 或 /api/v1 前缀,这里仅追加 user/models +const BASE = `${WEBUI_API_BASE_URL}/user/models`; + +export const listUserModels = async (token: string = '') => { + let error = null; + + const res = await fetch(BASE, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const createUserModel = async (token: string, body: UserModelCredentialForm) => { + let error = null; + + const res = await fetch(BASE, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(body) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserModel = async ( + token: string, + id: string, + body: UserModelCredentialForm +) => { + let error = null; + + const res = await fetch(`${BASE}/${id}`, { + method: 'PUT', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(body) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteUserModel = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${BASE}/${id}`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 16ae5030e3..a4a76b4a6a 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -13,34 +13,35 @@ import type { i18n as i18nType } from 'i18next'; import { WEBUI_BASE_URL } from '$lib/constants'; - import { - chatId, - chats, - config, - type Model, - models, - tags as allTags, - settings, - showSidebar, - WEBUI_NAME, - banners, - user, - socket, - showControls, - showCallOverlay, - currentChatPage, - temporaryChatEnabled, - mobile, - showOverview, - chatTitle, - showArtifacts, - tools, - toolServers, - functions, - selectedFolder, - pinnedChats, - showEmbeds - } from '$lib/stores'; +import { + chatId, + chats, + config, + type Model, + models, + userModels, + tags as allTags, + settings, + showSidebar, + WEBUI_NAME, + banners, + user, + socket, + showControls, + showCallOverlay, + currentChatPage, + temporaryChatEnabled, + mobile, + showOverview, + chatTitle, + showArtifacts, + tools, + toolServers, + functions, + selectedFolder, + pinnedChats, + showEmbeds +} from '$lib/stores'; import { convertMessagesToHistory, copyToClipboard, @@ -1550,11 +1551,12 @@ let memoryLocked = false; const submitPrompt = async (userPrompt, { _raw = false } = {}) => { console.log('submitPrompt', userPrompt, $chatId); - // === 1. 模型验证:确保选中的模型仍然存在 === - // 过滤掉已被删除或不可用的模型,避免发送请求时出错 - const _selectedModels = selectedModels.map((modelId) => - $models.map((m) => m.id).includes(modelId) ? modelId : '' - ); + // === 1. 模型验证:确保选中的模型仍然存在 === + // 过滤掉已被删除或不可用的模型,避免发送请求时出错 + const _selectedModels = selectedModels.map((modelId) => { + const allIds = [...$models.map((m) => m.id), ...$userModels.map((m) => m.id)]; + return allIds.includes(modelId) ? modelId : ''; + }); // 如果模型列表发生变化,同步更新 if (JSON.stringify(selectedModels) !== JSON.stringify(_selectedModels)) { @@ -1773,9 +1775,9 @@ let memoryLocked = false; // === 4. 为每个选中的模型创建响应消息占位符 === // 这样 UI 可以立即显示"正在输入..."状态 for (const [_modelIdx, modelId] of selectedModelIds.entries()) { - const model = $models.filter((m) => m.id === modelId).at(0); - - if (model) { + const combined = getCombinedModelById(modelId); + if (combined) { + const model = combined.model ?? combined.credential; // 4.1 生成响应消息 ID 和空消息对象 let responseMessageId = uuidv4(); let responseMessage = { @@ -1784,8 +1786,14 @@ let memoryLocked = false; childrenIds: [], role: 'assistant', content: '', // 初始为空,后续通过 WebSocket 流式填充 - model: model.id, - modelName: model.name ?? model.id, + model: + combined.source === 'user' && combined.credential + ? combined.credential.model_id + : model.id, + modelName: + combined.source === 'user' && combined.credential + ? combined.credential.name ?? combined.credential.model_id + : model.name ?? model.id, modelIdx: modelIdx ? modelIdx : _modelIdx, // 多模型对话时,区分不同模型的响应 timestamp: Math.floor(Date.now() / 1000) // Unix epoch }; @@ -1829,22 +1837,27 @@ let memoryLocked = false; await Promise.all( selectedModelIds.map(async (modelId, _modelIdx) => { console.log('modelId', modelId); - const model = $models.filter((m) => m.id === modelId).at(0); + const combined = getCombinedModelById(modelId); + const model = combined?.model ?? combined?.credential; - if (model) { - // 7.1 检查模型视觉能力(如果消息包含图片) - const hasImages = createMessagesList(_history, parentId).some((message) => - message.files?.some((file) => file.type === 'image') + if (combined && model) { + // 7.1 检查模型视觉能力(如果消息包含图片) + const hasImages = createMessagesList(_history, parentId).some((message) => + message.files?.some((file) => file.type === 'image') + ); + + // 如果消息包含图片,但模型不支持视觉,提示错误(私有模型默认视为支持) + if ( + combined.source !== 'user' && + hasImages && + !(model.info?.meta?.capabilities?.vision ?? true) + ) { + toast.error( + $i18n.t('Model {{modelName}} is not vision capable', { + modelName: model.name ?? model.id + }) ); - - // 如果消息包含图片,但模型不支持视觉,提示错误 - if (hasImages && !(model.info?.meta?.capabilities?.vision ?? true)) { - toast.error( - $i18n.t('Model {{modelName}} is not vision capable', { - modelName: model.name ?? model.id - }) - ); - } + } // 7.2 获取响应消息 ID let responseMessageId = @@ -1860,12 +1873,12 @@ let memoryLocked = false; // - 构造请求 payload(messages、files、tools、features 等) // - 调用 generateOpenAIChatCompletion API // - 处理流式响应(通过 WebSocket 实时更新消息内容) - await sendMessageSocket( - model, - messages && messages.length > 0 - ? messages // 使用自定义消息列表(例如重新生成时追加 follow-up) - : createMessagesList(_history, responseMessageId), // 使用完整历史记录 - _history, + await sendMessageSocket( + combined, + messages && messages.length > 0 + ? messages // 使用自定义消息列表(例如重新生成时追加 follow-up) + : createMessagesList(_history, responseMessageId), // 使用完整历史记录 + _history, responseMessageId, _chatId ); @@ -1883,8 +1896,8 @@ let memoryLocked = false; chats.set(await getChatList(localStorage.token, $currentChatPage)); }; - const getFeatures = () => { - let features = {}; +const getFeatures = () => { + let features = {}; if ($config?.features) features = { @@ -1921,17 +1934,27 @@ let memoryLocked = false; } // 如果用户手动切换了记忆开关,覆盖全局设置 - if (memoryEnabled !== undefined && memoryEnabled !== ($settings?.memory ?? false)) { - features = { ...features, memory: memoryEnabled }; - } + if (memoryEnabled !== undefined && memoryEnabled !== ($settings?.memory ?? false)) { + features = { ...features, memory: memoryEnabled }; + } - return features; - }; + return features; +}; - const sendMessageSocket = async (model, _messages, _history, responseMessageId, _chatId) => { +const getCombinedModelById = (modelId) => { + const platform = $models.find((m) => m.id === modelId); + if (platform) return { source: 'platform', model: platform }; + const priv = $userModels.find((m) => m.id === modelId); + if (priv) return { source: 'user', credential: priv }; + return null; +}; + + const sendMessageSocket = async (combinedModel, _messages, _history, responseMessageId, _chatId) => { const responseMessage = _history.messages[responseMessageId]; const userMessage = _history.messages[responseMessage.parentId]; + const model = combinedModel?.model ?? combinedModel?.credential ?? combinedModel; + const chatMessageFiles = _messages .filter((message) => message.files) .flatMap((message) => message.files); @@ -1972,6 +1995,9 @@ let memoryLocked = false; }); } + const isUserModel = combinedModel?.source === 'user'; + const credential = combinedModel?.credential; + const stream = model?.info?.params?.stream_response ?? $settings?.params?.stream_response ?? @@ -2039,7 +2065,7 @@ let memoryLocked = false; localStorage.token, { stream: stream, - model: model.id, + model: isUserModel ? credential.model_id : model.id, messages: messages, params: { ...$settings?.params, @@ -2063,7 +2089,8 @@ let memoryLocked = false; variables: { ...getPromptVariables($user?.name, $settings?.userLocation ? userLocation : undefined) }, - model_item: $models.find((m) => m.id === model.id), + model_item: isUserModel ? { credential_id: credential.id } : $models.find((m) => m.id === model.id), + is_user_model: isUserModel, session_id: $socket?.id, chat_id: $chatId, diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 554f501e03..7d7fd2f668 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -15,7 +15,15 @@ import { getChatById } from '$lib/apis/chats'; import { generateTags } from '$lib/apis'; - import { config, models, settings, temporaryChatEnabled, TTSWorker, user } from '$lib/stores'; +import { + config, + models, + settings, + temporaryChatEnabled, + TTSWorker, + user, + userModels +} from '$lib/stores'; import { synthesizeOpenAISpeech } from '$lib/apis/audio'; import { imageGenerations } from '$lib/apis/images'; import { @@ -147,8 +155,30 @@ let buttonsContainerElement: HTMLDivElement; let showDeleteConfirm = false; - let model = null; - $: model = $models.find((m) => m.id === message.model); +let model = null; +let userModel = null; +let modelName = ''; + +$: { + const platformModel = $models.find((m) => m.id === message.model); + const credentialModel = (() => { + // 优先使用随消息携带的 credential_id 精确匹配 + if (message?.model_item?.credential_id) { + return $userModels.find((m) => m.id === message.model_item.credential_id); + } + // 兼容历史消息:通过 model_id 反查 + return $userModels.find((m) => m.model_id === message.model); + })(); + + userModel = credentialModel ? { ...credentialModel, id: credentialModel.model_id } : null; + model = platformModel ?? userModel; + + modelName = + message.modelName ?? + model?.name ?? + (userModel?.name ?? userModel?.model_id) ?? + message.model; +} let edit = false; let editedContent = ''; @@ -619,10 +649,10 @@