From 7c81509804e280cddcd0188d21e9ffb0caa8c242 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sun, 18 Aug 2024 20:59:59 +0200
Subject: [PATCH] feat: merge responses

---
 backend/constants.py                          |   1 +
 backend/main.py                               |  53 +++++++++
 backend/utils/task.py                         |  37 ++++++
 src/app.css                                   |   4 +
 src/lib/apis/index.ts                         |  36 ++++++
 src/lib/components/chat/Chat.svelte           |  53 ++++++++-
 src/lib/components/chat/Messages.svelte       |   2 +
 .../components/chat/Messages/Citations.svelte |  56 +++++++++
 src/lib/components/chat/Messages/Error.svelte |  26 +++++
 .../components/chat/Messages/Markdown.svelte  |  32 ++++++
 .../Messages/MultiResponseMessages.svelte     |  86 +++++++++++---
 .../chat/Messages/ResponseMessage.svelte      | 107 ++----------------
 .../chat/Messages/UserMessage.svelte          |  14 +--
 13 files changed, 378 insertions(+), 129 deletions(-)
 create mode 100644 src/lib/components/chat/Messages/Citations.svelte
 create mode 100644 src/lib/components/chat/Messages/Error.svelte
 create mode 100644 src/lib/components/chat/Messages/Markdown.svelte

diff --git a/backend/constants.py b/backend/constants.py
index b9c7fc430d..d55216bb5d 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -100,3 +100,4 @@ class TASKS(str, Enum):
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"
     FUNCTION_CALLING = "function_calling"
+    MOA_RESPONSE_GENERATION = "moa_response_generation"
diff --git a/backend/main.py b/backend/main.py
index 6c75164dcb..d539834eda 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -73,6 +73,7 @@ from utils.task import (
     title_generation_template,
     search_query_generation_template,
     tools_function_calling_generation_template,
+    moa_response_generation_template,
 )
 from utils.misc import (
     get_last_user_message,
@@ -1570,6 +1571,58 @@ Message: """{{prompt}}"""
     return await generate_chat_completions(form_data=payload, user=user)
 
 
+@app.post("/api/task/moa/completions")
+async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
+    print("generate_moa_response")
+
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    model_id = get_task_model_id(model_id)
+    print(model_id)
+
+    template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
+
+Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
+
+Responses from models: {{responses}}"""
+
+    content = moa_response_generation_template(
+        template,
+        form_data["prompt"],
+        form_data["responses"],
+    )
+
+    payload = {
+        "model": model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": form_data.get("stream", False),
+        "chat_id": form_data.get("chat_id", None),
+        "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
+    }
+
+    log.debug(payload)
+
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        return JSONResponse(
+            status_code=e.args[0],
+            content={"detail": e.args[1]},
+        )
+
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
+    return await generate_chat_completions(form_data=payload, user=user)
+
+
 ##################################
 #
 # Pipelines Endpoints
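The endpoint above reads `model`, `prompt`, `responses`, and optionally `stream` and `chat_id` from the request body. A minimal sketch of exercising it from a script, assuming an Open WebUI instance at `http://localhost:8080`, a valid API key, and a known model id (all placeholders), and assuming the non-streaming path returns an OpenAI-style completion object as the other task endpoints do:

```python
# Hypothetical client call; base URL, token, and model id are assumptions.
import requests

BASE_URL = "http://localhost:8080"
TOKEN = "sk-..."  # placeholder API key

payload = {
    "model": "llama3:latest",  # any model id the server knows about
    "prompt": "What causes ocean tides?",
    "responses": [
        "Tides are caused by the Moon's gravitational pull.",
        "Tides result from the combined gravity of the Moon and Sun acting on Earth's oceans.",
    ],
    "stream": False,
}

r = requests.post(
    f"{BASE_URL}/api/task/moa/completions",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json=payload,
    timeout=120,
)
r.raise_for_status()
# Assuming an OpenAI-style response body:
print(r.json()["choices"][0]["message"]["content"])
```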
diff --git a/backend/utils/task.py b/backend/utils/task.py
index 1b2276c9c5..ea9254c4f7 100644
--- a/backend/utils/task.py
+++ b/backend/utils/task.py
@@ -121,6 +121,43 @@ def search_query_generation_template(
     return template
 
 
+def moa_response_generation_template(
+    template: str, prompt: str, responses: list[str]
+) -> str:
+    def replacement_function(match):
+        full_match = match.group(0)
+        start_length = match.group(1)
+        end_length = match.group(2)
+        middle_length = match.group(3)
+
+        if full_match == "{{prompt}}":
+            return prompt
+        elif start_length is not None:
+            return prompt[: int(start_length)]
+        elif end_length is not None:
+            return prompt[-int(end_length) :]
+        elif middle_length is not None:
+            middle_length = int(middle_length)
+            if len(prompt) <= middle_length:
+                return prompt
+            start = prompt[: math.ceil(middle_length / 2)]
+            end = prompt[-math.floor(middle_length / 2) :]
+            return f"{start}...{end}"
+        return ""
+
+    template = re.sub(
+        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
+        replacement_function,
+        template,
+    )
+
+    responses = [f'"""{response}"""' for response in responses]
+    responses = "\n\n".join(responses)
+
+    template = template.replace("{{responses}}", responses)
+    return template
+
+
 def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
     template = template.replace("{{TOOLS}}", tools_specs)
     return template
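A quick illustration of how `moa_response_generation_template` expands its placeholders; the template string and inputs below are made up for the example, and the import assumes the snippet runs inside the backend package:

```python
# Illustration only; the template and inputs are hypothetical.
from utils.task import moa_response_generation_template

template = (
    'Latest user query: "{{prompt:middletruncate:20}}"\n\n'
    "Responses from models: {{responses}}"
)
prompt = "Summarize the causes of the French Revolution in one paragraph."
responses = ["Answer from model A", "Answer from model B"]

content = moa_response_generation_template(template, prompt, responses)

# {{prompt:middletruncate:20}} keeps the first 10 and last 10 characters of the prompt
# ("Summarize ...paragraph."), while {{responses}} becomes each response wrapped in
# triple quotes and joined by blank lines:
#
#   Latest user query: "Summarize ...paragraph."
#
#   Responses from models: """Answer from model A"""
#
#   """Answer from model B"""
```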
diff --git a/src/app.css b/src/app.css
index 4345bb3777..a421d90ae4 100644
--- a/src/app.css
+++ b/src/app.css
@@ -34,6 +34,10 @@ math {
 	@apply rounded-lg;
 }
 
+.markdown-prose {
+	@apply prose dark:prose-invert prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
+}
+
 .markdown a {
 	@apply underline;
 }
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index c4778cadbd..fc01c209dd 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -333,6 +333,42 @@ export const generateSearchQuery = async (
 	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? prompt;
 };
 
+export const generateMoACompletion = async (
+	token: string = '',
+	model: string,
+	prompt: string,
+	responses: string[]
+) => {
+	const controller = new AbortController();
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/moa/completions`, {
+		signal: controller.signal,
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			prompt: prompt,
+			responses: responses,
+			stream: true
+		})
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return [res, controller];
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index 3da6eab037..2703d65786 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -54,7 +54,13 @@
 	import { createOpenAITextStream } from '$lib/apis/streaming';
 	import { queryMemory } from '$lib/apis/memories';
 	import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
-	import { chatCompleted, generateTitle, generateSearchQuery, chatAction } from '$lib/apis';
+	import {
+		chatCompleted,
+		generateTitle,
+		generateSearchQuery,
+		chatAction,
+		generateMoACompletion
+	} from '$lib/apis';
 
 	import Banner from '../common/Banner.svelte';
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -1511,6 +1517,50 @@
 			return [];
 		});
 	};
+
+	const mergeResponses = async (messageId, responses) => {
+		console.log('mergeResponses', messageId, responses);
+		const message = history.messages[messageId];
+		const mergedResponse = {
+			status: true,
+			content: ''
+		};
+
+		message.merged = mergedResponse;
+		try {
+			const [res, controller] = await generateMoACompletion(
+				localStorage.token,
+				message.model,
+				history.messages[message.parentId].content,
+				responses
+			);
+
+			if (res && res.ok && res.body) {
+				const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
+				for await (const update of textStream) {
+					const { value, done, citations, error, usage } = update;
+					if (error || done) {
+						break;
+					}
+
+					if (mergedResponse.content == '' && value == '\n') {
+						continue;
+					} else {
+						mergedResponse.content += value;
+						messages = messages;
+					}
+
+					if (autoScroll) {
+						scrollToBottom();
+					}
+				}
+			} else {
+				console.error(res);
+			}
+		} catch (e) {
+			console.error(e);
+		}
+	};
@@ -1637,6 +1687,7 @@
 			{sendPrompt}
 			{continueGeneration}
 			{regenerateResponse}
+			{mergeResponses}
 			{chatActionHandler}
 		/>
diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte
index e1e30059e9..512014cb42 100644
--- a/src/lib/components/chat/Messages.svelte
+++ b/src/lib/components/chat/Messages.svelte
@@ -19,6 +19,7 @@
 	export let sendPrompt: Function;
 	export let continueGeneration: Function;
 	export let regenerateResponse: Function;
+	export let mergeResponses: Function;
 
 	export let chatActionHandler: Function;
 	export let user = $_user;
@@ -374,6 +375,7 @@
 							{rateMessage}
 							copyToClipboard={copyToClipboardWithToast}
 							{continueGeneration}
+							{mergeResponses}
 							{regenerateResponse}
 							on:change={async () => {
 								await updateChatById(localStorage.token, chatId, {
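The UI consumes the merged completion through `createOpenAITextStream`, as `mergeResponses` above shows. Outside the UI the same stream can be read directly; a rough sketch, assuming the endpoint relays OpenAI-style `data: {...}` SSE chunks when `stream` is true (the wire format, URL, token, and model id are all assumptions):

```python
# Hypothetical streaming consumer; wire format, URL, token, and model id are assumptions.
import json
import requests

with requests.post(
    "http://localhost:8080/api/task/moa/completions",
    headers={"Authorization": "Bearer sk-..."},
    json={
        "model": "llama3:latest",
        "prompt": "What causes ocean tides?",
        "responses": ["Draft answer one.", "Draft answer two."],
        "stream": True,
    },
    stream=True,
    timeout=120,
) as r:
    for line in r.iter_lines():
        # Each SSE chunk is expected to look like: data: {"choices":[{"delta":{"content":"..."}}]}
        if not line or not line.startswith(b"data: "):
            continue
        chunk = line[len(b"data: "):]
        if chunk == b"[DONE]":
            break
        delta = json.loads(chunk)["choices"][0]["delta"].get("content", "")
        print(delta, end="", flush=True)
```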
diff --git a/src/lib/components/chat/Messages/Citations.svelte b/src/lib/components/chat/Messages/Citations.svelte
new file mode 100644
index 0000000000..8112d37f49
--- /dev/null
+++ b/src/lib/components/chat/Messages/Citations.svelte
@@ -0,0 +1,56 @@
+
+
+
+
+	{#each citations.reduce((acc, citation) => {
+		citation.document.forEach((document, index) => {
+			const metadata = citation.metadata?.[index];
+			const id = metadata?.source ?? 'N/A';
+			let source = citation?.source;
+
+			if (metadata?.name) {
+				source = { ...source, name: metadata.name };
+			}
+
+			// Check if ID looks like a URL
+			if (id.startsWith('http://') || id.startsWith('https://')) {
+				source = { name: id };
+			}
+
+			const existingSource = acc.find((item) => item.id === id);
+
+			if (existingSource) {
+				existingSource.document.push(document);
+				existingSource.metadata.push(metadata);
+			} else {
+				acc.push({ id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] });
+			}
+		});
+		return acc;
+	}, []) as citation, idx}
+
+
+	{/each}
+
diff --git a/src/lib/components/chat/Messages/Error.svelte b/src/lib/components/chat/Messages/Error.svelte
new file mode 100644
index 0000000000..a1fed2f421
--- /dev/null
+++ b/src/lib/components/chat/Messages/Error.svelte
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+	{content}
+
diff --git a/src/lib/components/chat/Messages/Markdown.svelte b/src/lib/components/chat/Messages/Markdown.svelte
new file mode 100644
index 0000000000..2c2f74d768
--- /dev/null
+++ b/src/lib/components/chat/Messages/Markdown.svelte
@@ -0,0 +1,32 @@
+
+
+{#key id}
+
+{/key}
diff --git a/src/lib/components/chat/Messages/MultiResponseMessages.svelte b/src/lib/components/chat/Messages/MultiResponseMessages.svelte
index d0be97d82c..7a760f07a3 100644
--- a/src/lib/components/chat/Messages/MultiResponseMessages.svelte
+++ b/src/lib/components/chat/Messages/MultiResponseMessages.svelte
@@ -1,11 +1,21 @@
-
-
 {#key message.id}
 {/if}
-
+
 {#if (message?.statusHistory ?? [...(message?.status ? [message?.status] : [])]).length > 0}
 	{@const status = (
@@ -408,82 +382,15 @@
 {:else if message.content && message.error !== true}
-	{#key message.id}
-
-	{/key}
+
 {/if}
 {#if message.error}
-
-
-
-
-
-
-	{message?.error?.content ?? message.content}
-
-
+
 {/if}
 {#if message.citations}
-
-	{#each message.citations.reduce((acc, citation) => {
-		citation.document.forEach((document, index) => {
-			const metadata = citation.metadata?.[index];
-			const id = metadata?.source ?? 'N/A';
-			let source = citation?.source;
-
-			if (metadata?.name) {
-				source = { ...source, name: metadata.name };
-			}
-
-			// Check if ID looks like a URL
-			if (id.startsWith('http://') || id.startsWith('https://')) {
-				source = { name: id };
-			}
-
-			const existingSource = acc.find((item) => item.id === id);
-
-			if (existingSource) {
-				existingSource.document.push(document);
-				existingSource.metadata.push(metadata);
-			} else {
-				acc.push({ id: id, source: source, document: [document], metadata: metadata ? [metadata] : [] });
-			}
-		});
-		return acc;
-	}, []) as citation, idx}
-
-
-	{/each}
-
+
 {/if}
 {/if}
diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte
index 11d14523f8..67f682520f 100644
--- a/src/lib/components/chat/Messages/UserMessage.svelte
+++ b/src/lib/components/chat/Messages/UserMessage.svelte
@@ -13,6 +13,7 @@
 	import { marked } from 'marked';
 	import { processResponseContent, replaceTokens } from '$lib/utils';
 	import MarkdownTokens from './Markdown/MarkdownTokens.svelte';
+	import Markdown from './Markdown.svelte';
 
 	const i18n = getContext('i18n');
@@ -93,9 +94,7 @@
 		{/if}
-
+
 		{#if message.files}
 			{#each message.files as file}
@@ -174,14 +173,7 @@
 					: ' w-full'}"
 				>
 					{#if message.content}
-
-						{#key message.id}
-
-						{/key}
-
+
 					{/if}