diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 88cecc940a..f45cf0d12e 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -317,19 +317,31 @@ class GenerateImageForm(BaseModel): def save_b64_image(b64_str): try: - header, encoded = b64_str.split(",", 1) - mime_type = header.split(";")[0] - - img_data = base64.b64decode(encoded) - image_id = str(uuid.uuid4()) - image_format = mimetypes.guess_extension(mime_type) - image_filename = f"{image_id}{image_format}" - file_path = IMAGE_CACHE_DIR / f"{image_filename}" - with open(file_path, "wb") as f: - f.write(img_data) - return image_filename + if "," in b64_str: + header, encoded = b64_str.split(",", 1) + mime_type = header.split(";")[0] + + img_data = base64.b64decode(encoded) + image_format = mimetypes.guess_extension(mime_type) + + image_filename = f"{image_id}{image_format}" + file_path = IMAGE_CACHE_DIR / f"{image_filename}" + with open(file_path, "wb") as f: + f.write(img_data) + return image_filename + else: + image_filename = f"{image_id}.png" + file_path = IMAGE_CACHE_DIR.joinpath(image_filename) + + img_data = base64.b64decode(b64_str) + + # Write the image data to a file + with open(file_path, "wb") as f: + f.write(img_data) + return image_filename + except Exception as e: log.exception(f"Error saving image: {e}") return None @@ -348,18 +360,20 @@ def save_url_image(url): if not image_format: raise ValueError("Could not determine image type from MIME type") - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}") + image_filename = f"{image_id}{image_format}" + + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}") with open(file_path, "wb") as image_file: for chunk in r.iter_content(chunk_size=8192): image_file.write(chunk) - return image_id, image_format + return image_filename else: log.error(f"Url does not point to an image.") - return None, None + return None except Exception as e: log.exception(f"Error saving image: {e}") - return None, None + return None @app.post("/generations") @@ -400,7 +414,7 @@ def generate_image( for image in res["data"]: image_filename = save_b64_image(image["b64_json"]) images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") with open(file_body_path, "w") as f: json.dump(data, f) @@ -435,11 +449,9 @@ def generate_image( images = [] for image in res["data"]: - image_id, image_format = save_url_image(image["url"]) - images.append( - {"url": f"/cache/image/generations/{image_id}{image_format}"} - ) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + image_filename = save_url_image(image["url"]) + images.append({"url": f"/cache/image/generations/{image_filename}"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") with open(file_body_path, "w") as f: json.dump(data.model_dump(exclude_none=True), f) @@ -477,7 +489,7 @@ def generate_image( for image in res["images"]: image_filename = save_b64_image(image) images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") with open(file_body_path, "w") as f: json.dump({**data, "info": res["info"]}, f) diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 2a37622f1c..95f4420672 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -36,6 +36,10 @@ from config import ( LITELLM_PROXY_HOST, ) +import warnings + +warnings.simplefilter("ignore") + from litellm.utils import get_llm_provider import asyncio diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index f1c836faac..6d7e4f8152 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -31,7 +31,12 @@ from typing import Optional, List, Union from apps.web.models.users import Users from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user, get_admin_user +from utils.utils import ( + decode_token, + get_current_user, + get_verified_user, + get_admin_user, +) from config import ( @@ -164,7 +169,7 @@ async def get_all_models(): @app.get("/api/tags") @app.get("/api/tags/{url_idx}") async def get_ollama_tags( - url_idx: Optional[int] = None, user=Depends(get_current_user) + url_idx: Optional[int] = None, user=Depends(get_verified_user) ): if url_idx == None: models = await get_all_models() @@ -563,7 +568,7 @@ async def delete_model( @app.post("/api/show") -async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): +async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): if form_data.name not in app.state.MODELS: raise HTTPException( status_code=400, @@ -612,7 +617,7 @@ class GenerateEmbeddingsForm(BaseModel): async def generate_embeddings( form_data: GenerateEmbeddingsForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: model = form_data.model @@ -730,7 +735,7 @@ class GenerateCompletionForm(BaseModel): async def generate_completion( form_data: GenerateCompletionForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: @@ -833,7 +838,7 @@ class GenerateChatCompletionForm(BaseModel): async def generate_chat_completion( form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: @@ -942,7 +947,7 @@ class OpenAIChatCompletionForm(BaseModel): async def generate_openai_chat_completion( form_data: OpenAIChatCompletionForm, url_idx: Optional[int] = None, - user=Depends(get_current_user), + user=Depends(get_verified_user), ): if url_idx == None: @@ -1241,7 +1246,9 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): +async def deprecated_proxy( + path: str, request: Request, user=Depends(get_verified_user) +): url = app.state.OLLAMA_BASE_URLS[0] target_url = f"{url}/{path}" diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 73b15e04ae..f72ed79b37 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -93,6 +93,31 @@ async def get_archived_session_user_chat_list( return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): + if user.role == "pending": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role == "user": + chat = Chats.get_chat_by_share_id(share_id) + elif user.role == "admin": + chat = Chats.get_chat_by_id(share_id) + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + ############################ # GetChats ############################ @@ -141,6 +166,55 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): ) +############################ +# GetChatsByTags +############################ + + +class TagNameForm(BaseModel): + name: str + skip: Optional[int] = 0 + limit: Optional[int] = 50 + + +@router.post("/tags", response_model=List[ChatTitleIdResponse]) +async def get_user_chat_list_by_tag_name( + form_data: TagNameForm, user=Depends(get_current_user) +): + + print(form_data) + chat_ids = [ + chat_id_tag.chat_id + for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( + form_data.name, user.id + ) + ] + + chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) + + if len(chats) == 0: + Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) + + return chats + + +############################ +# GetAllTags +############################ + + +@router.get("/tags/all", response_model=List[TagModel]) +async def get_all_tags(user=Depends(get_current_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetChatById ############################ @@ -274,70 +348,6 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): ) -############################ -# GetSharedChatById -############################ - - -@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): - if user.role == "pending": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - if user.role == "user": - chat = Chats.get_chat_by_share_id(share_id) - elif user.role == "admin": - chat = Chats.get_chat_by_id(share_id) - - if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - -############################ -# GetAllTags -############################ - - -@router.get("/tags/all", response_model=List[TagModel]) -async def get_all_tags(user=Depends(get_current_user)): - try: - tags = Tags.get_tags_by_user_id(user.id) - return tags - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) - - -############################ -# GetChatsByTags -############################ - - -@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse]) -async def get_user_chat_list_by_tag_name( - tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50 -): - chat_ids = [ - chat_id_tag.chat_id - for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id) - ] - - chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit) - - if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id) - - return chats - - ############################ # GetChatTagsById ############################ diff --git a/backend/main.py b/backend/main.py index e37ac324b9..e36b8296af 100644 --- a/backend/main.py +++ b/backend/main.py @@ -25,6 +25,8 @@ from apps.litellm.main import ( start_litellm_background, shutdown_litellm_background, ) + + from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app @@ -74,7 +76,7 @@ class SPAStaticFiles(StaticFiles): print( - f""" + rf""" ___ __ __ _ _ _ ___ / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| | | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | diff --git a/cypress/e2e/chat.cy.ts b/cypress/e2e/chat.cy.ts index fce7862727..f46bef57b3 100644 --- a/cypress/e2e/chat.cy.ts +++ b/cypress/e2e/chat.cy.ts @@ -21,14 +21,14 @@ describe('Settings', () => { // Click on the model selector cy.get('button[aria-label="Select a model"]').click(); // Select the first model - cy.get('div[role="option"][data-value]').first().click(); + cy.get('button[aria-label="model-item"]').first().click(); }); it('user can perform text chat', () => { // Click on the model selector cy.get('button[aria-label="Select a model"]').click(); // Select the first model - cy.get('div[role="option"][data-value]').first().click(); + cy.get('button[aria-label="model-item"]').first().click(); // Type a message cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', { force: true diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 7b76b11fa9..a72b519397 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { getTimeRange } from '$lib/utils'; export const createNewChat = async (token: string, chat: object) => { let error = null; @@ -59,7 +60,10 @@ export const getChatList = async (token: string = '') => { throw error; } - return res; + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); }; export const getChatListByUserId = async (token: string = '', userId: string) => { @@ -90,7 +94,10 @@ export const getChatListByUserId = async (token: string = '', userId: string) => throw error; } - return res; + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); }; export const getArchivedChatList = async (token: string = '') => { @@ -220,13 +227,16 @@ export const getAllChatTags = async (token: string) => { export const getChatListByTagName = async (token: string = '', tagName: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, { - method: 'GET', + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags`, { + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', ...(token && { authorization: `Bearer ${token}` }) - } + }, + body: JSON.stringify({ + name: tagName + }) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -245,7 +255,10 @@ export const getChatListByTagName = async (token: string = '', tagName: string) throw error; } - return res; + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); }; export const getChatById = async (token: string, id: string) => { diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 7d377d486e..6711ea2b5d 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -411,7 +411,9 @@ {#if dragged}
{ if ( + window.innerWidth > 1024 || !( 'ontouchstart' in window || navigator.maxTouchPoints > 0 || diff --git a/src/lib/components/chat/ModelSelector/Selector.svelte b/src/lib/components/chat/ModelSelector/Selector.svelte index 6775ecc9e5..0f1ebfbf96 100644 --- a/src/lib/components/chat/ModelSelector/Selector.svelte +++ b/src/lib/components/chat/ModelSelector/Selector.svelte @@ -25,7 +25,9 @@ export let items = [{ value: 'mango', label: 'Mango' }]; - export let className = 'max-w-lg'; + export let className = ' w-[32rem]'; + + let show = false; let selectedModel = ''; $: selectedModel = items.find((item) => item.value === value) ?? ''; @@ -181,6 +183,7 @@ { searchValue = ''; window.setTimeout(() => document.getElementById('model-search-input')?.focus(), 0); @@ -199,7 +202,7 @@
@@ -222,10 +226,13 @@
{#each filteredItems as item} - { value = item.value; + + show = false; }} >
@@ -294,7 +301,7 @@
{/if} -
+ {:else}
@@ -392,6 +399,9 @@
{/each}
+ +
+ +
+
{$i18n.t('Input commands')}
+
+ +
+
+
+
+
+ {$i18n.t('Attach file')} +
+ +
+
+ # +
+
+
+ +
+
+ {$i18n.t('Add custom prompt')} +
+ +
+
+ / +
+
+
+ +
+
+ {$i18n.t('Select model')} +
+ +
+
+ @ +
+
+
+
+
+
diff --git a/src/lib/components/chat/Tags.svelte b/src/lib/components/chat/Tags.svelte index 17ec064ee8..47e63198fb 100644 --- a/src/lib/components/chat/Tags.svelte +++ b/src/lib/components/chat/Tags.svelte @@ -3,11 +3,15 @@ addTagById, deleteTagById, getAllChatTags, + getChatList, + getChatListByTagName, getTagsById, updateChatById } from '$lib/apis/chats'; - import { tags as _tags } from '$lib/stores'; - import { onMount } from 'svelte'; + import { tags as _tags, chats } from '$lib/stores'; + import { createEventDispatcher, onMount } from 'svelte'; + + const dispatch = createEventDispatcher(); import Tags from '../common/Tags.svelte'; @@ -39,7 +43,21 @@ tags: tags }); - _tags.set(await getAllChatTags(localStorage.token)); + console.log($_tags); + + await _tags.set(await getAllChatTags(localStorage.token)); + + console.log($_tags); + + if ($_tags.map((t) => t.name).includes(tagName)) { + await chats.set(await getChatListByTagName(localStorage.token, tagName)); + + if ($chats.find((chat) => chat.id === chatId)) { + dispatch('close'); + } + } else { + await chats.set(await getChatList(localStorage.token)); + } }; onMount(async () => { diff --git a/src/lib/components/common/ImagePreview.svelte b/src/lib/components/common/ImagePreview.svelte index badabebda4..99882cbca4 100644 --- a/src/lib/components/common/ImagePreview.svelte +++ b/src/lib/components/common/ImagePreview.svelte @@ -51,7 +51,7 @@