diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 0ebb5dfb4e..10b5cbdc7b 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -89,6 +89,7 @@ import Tooltip from '../common/Tooltip.svelte'; import Sidebar from '../icons/Sidebar.svelte'; import { getFunctions } from '$lib/apis/functions'; + import Image from '../common/Image.svelte'; export let chatIdProp = ''; @@ -1787,6 +1788,23 @@ })) .filter((message) => message?.role === 'user' || message?.content?.trim()); + const toolIds = []; + const toolServerIds = []; + + for (const toolId of selectedToolIds) { + if (toolId.startsWith('direct_server:')) { + let serverId = toolId.replace('direct_server:', ''); + // Check if serverId is a number + if (!isNaN(parseInt(serverId))) { + toolServerIds.push(parseInt(serverId)); + } else { + toolServerIds.push(serverId); + } + } else { + toolIds.push(toolId); + } + } + const res = await generateOpenAIChatCompletion( localStorage.token, { @@ -1807,8 +1825,10 @@ files: (files?.length ?? 0) > 0 ? files : undefined, filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined, - tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, - tool_servers: $toolServers, + tool_ids: toolIds.length > 0 ? toolIds : undefined, + tool_servers: ($toolServers ?? []).filter( + (server, idx) => toolServerIds.includes(idx) || toolServerIds.includes(server?.id) + ), features: getFeatures(), variables: { ...getPromptVariables($user?.name, $settings?.userLocation ? userLocation : undefined) diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 76dd59c396..6c0a7d8619 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -15,11 +15,11 @@ type Model, mobile, settings, - showSidebar, models, config, showCallOverlay, tools, + toolServers, user as _user, showControls, TTSWorker, @@ -45,6 +45,7 @@ import { generateAutoCompletion } from '$lib/apis'; import { deleteFileById } from '$lib/apis/files'; import { getSessionUser } from '$lib/apis/auths'; + import { getTools } from '$lib/apis/tools'; import { WEBUI_BASE_URL, WEBUI_API_BASE_URL, PASTED_TEXT_CHARACTER_LIMIT } from '$lib/constants'; @@ -99,8 +100,6 @@ export let prompt = ''; export let files = []; - export let toolServers = []; - export let selectedToolIds = []; export let selectedFilterIds = []; @@ -442,7 +441,7 @@ .reduce((acc, filters) => acc.filter((f1) => filters.some((f2) => f2.id === f1.id))); let showToolsButton = false; - $: showToolsButton = toolServers.length + selectedToolIds.length > 0; + $: showToolsButton = ($tools ?? []).length > 0 || ($toolServers ?? []).length > 0; let showWebSearchButton = false; $: showWebSearchButton = @@ -902,6 +901,8 @@ dropzoneElement?.addEventListener('dragover', onDragOver); dropzoneElement?.addEventListener('drop', onDrop); dropzoneElement?.addEventListener('dragleave', onDragLeave); + + await tools.set(await getTools(localStorage.token)); }); onDestroy(() => { @@ -1457,10 +1458,10 @@ {/if}
- {#if showToolsButton} + {#if (selectedToolIds ?? []).length > 0} diff --git a/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte b/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte index 9b99be7add..62258e8883 100644 --- a/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte +++ b/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte @@ -4,7 +4,7 @@ import { fly } from 'svelte/transition'; import { flyAndScale } from '$lib/utils/transitions'; - import { config, user, tools as _tools, mobile, settings } from '$lib/stores'; + import { config, user, tools as _tools, mobile, settings, toolServers } from '$lib/stores'; import { getTools } from '$lib/apis/tools'; @@ -55,7 +55,10 @@ ($user?.role === 'admin' || $user?.permissions?.chat?.file_upload); const init = async () => { - await _tools.set(await getTools(localStorage.token)); + if ($_tools === null) { + await _tools.set(await getTools(localStorage.token)); + } + if ($_tools) { tools = $_tools.reduce((a, tool, i, arr) => { a[tool.id] = { @@ -65,8 +68,22 @@ }; return a; }, {}); - selectedToolIds = selectedToolIds.filter((id) => $_tools?.some((tool) => tool.id === id)); } + + if ($toolServers) { + for (const serverIdx in $toolServers) { + const server = $toolServers[serverIdx]; + if (server.info) { + tools[`direct_server:${serverIdx}`] = { + name: server?.info?.title ?? server.url, + description: server.info.description ?? '', + enabled: selectedToolIds.includes(`direct_server:${serverIdx}`) + }; + } + } + } + + selectedToolIds = selectedToolIds.filter((id) => Object.keys(tools).includes(id)); };