diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py
index f82076809e..94875d9595 100644
--- a/backend/apps/images/utils/comfyui.py
+++ b/backend/apps/images/utils/comfyui.py
@@ -1,6 +1,5 @@
 import asyncio
 import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
-import uuid
 import json
 import urllib.request
 import urllib.parse
@@ -398,7 +397,9 @@ async def comfyui_generate_image(
         return None

     try:
-        images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url)
+        images = await asyncio.to_thread(
+            get_images, ws, comfyui_prompt, client_id, base_url
+        )
     except Exception as e:
         log.exception(f"Error while receiving images: {e}")
         images = None
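The reflow above is cosmetic, but the call it touches is worth a note: `get_images` drives the blocking `websocket-client` socket, so it is handed to `asyncio.to_thread` instead of being called on the event loop. A minimal sketch of the same pattern (demo names only, not from this patch):

```python
import asyncio
import time

def blocking_receive() -> str:
    # stands in for get_images(), which loops over ws.recv() and HTTP polling
    time.sleep(1)
    return "image-bytes"

async def main() -> None:
    # runs the blocking call on a worker thread so the event loop stays responsive
    result = await asyncio.to_thread(blocking_receive)
    print(result)

asyncio.run(main())
```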
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 442d99ff26..d5ef829423 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -1,27 +1,21 @@
 from fastapi import (
     FastAPI,
     Request,
-    Response,
     HTTPException,
     Depends,
-    status,
     UploadFile,
     File,
-    BackgroundTasks,
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel, ConfigDict

 import os
 import re
-import copy
 import random
 import requests
 import json
-import uuid
 import aiohttp
 import asyncio
 import logging
@@ -32,16 +26,11 @@ from typing import Optional, List, Union
 from starlette.background import BackgroundTask

 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_current_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template
-
 from config import (
     SRC_LOG_LEVELS,
@@ -53,7 +42,12 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256, add_or_update_system_message
+from utils.misc import (
+    apply_model_params_to_body_ollama,
+    calculate_sha256,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
                 res = await r.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"Ollama: {e}"

         raise HTTPException(
@@ -238,7 +232,7 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()

         if app.state.config.ENABLE_MODEL_FILTER:
@@ -269,7 +263,7 @@ async def get_ollama_tags(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
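The repeated `except:` to `except Exception:` fixes in this file matter because a bare `except` also traps `BaseException` subclasses such as `KeyboardInterrupt` and `SystemExit`, which should normally propagate; likewise `url_idx is None` is the correct identity test, where `== None` would go through `__eq__` and can misfire on objects with custom equality. A small sketch of the except-clause difference (hypothetical helper, not from this patch):

```python
import requests

def fetch_version(url: str) -> str:
    r = requests.get(f"{url}/api/version", timeout=5)
    try:
        return r.json()["version"]
    except Exception:
        # catches ValueError from .json() or a missing "version" key, but a
        # Ctrl-C (KeyboardInterrupt) still propagates; a bare "except:" would
        # have swallowed that too
        return "unknown"
```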
@@ -282,8 +276,7 @@
 @app.get("/api/version/{url_idx}")
 async def get_ollama_versions(url_idx: Optional[int] = None):
     if app.state.config.ENABLE_OLLAMA_API:
-        if url_idx == None:
-
+        if url_idx is None:
             # returns lowest version
             tasks = [
                 fetch_url(f"{url}/api/version")
@@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
@@ -346,8 +339,6 @@ async def pull_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
@@ -367,7 +358,7 @@ async def push_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
@@ -417,7 +408,7 @@ async def copy_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.source in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.source]["urls"][0]
         else:
@@ -428,13 +419,13 @@ async def copy_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/copy",
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
     try:
-        r = requests.request(
-            method="POST",
-            url=f"{url}/api/copy",
-            data=form_data.model_dump_json(exclude_none=True).encode(),
-        )
         r.raise_for_status()

         log.debug(f"r.text: {r.text}")
@@ -448,7 +439,7 @@ async def copy_model(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
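Several handlers here also hoist `requests.request(...)` out of the `try:`. That narrows the handler to the `raise_for_status()` and parsing failures it was written for, and it guarantees `r` is bound when the `except` path calls `r.json()`, which is why placeholders like the deleted `r = None` are no longer needed. The trade-off, worth knowing, is that connection-level errors now raise before the `try:` instead of becoming the friendly `HTTPException`. The shape, as a hypothetical wrapper:

```python
import requests
from fastapi import HTTPException

def post_to_ollama(url: str, path: str, body: bytes) -> dict:
    # request is built and sent outside the try, so `r` is always bound
    # inside the handler below
    r = requests.request(method="POST", url=f"{url}{path}", data=body)
    try:
        r.raise_for_status()
        return r.json()
    except Exception as e:
        detail = f"Ollama: {e}"
        try:
            res = r.json()
            if "error" in res:
                detail = f"Ollama: {res['error']}"
        except Exception:
            pass
        raise HTTPException(status_code=r.status_code, detail=detail)
```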
@@ -464,7 +455,7 @@ async def delete_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
@@ -476,12 +467,12 @@ async def delete_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    r = requests.request(
+        method="DELETE",
+        url=f"{url}/api/delete",
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
     try:
-        r = requests.request(
-            method="DELETE",
-            url=f"{url}/api/delete",
-            data=form_data.model_dump_json(exclude_none=True).encode(),
-        )
         r.raise_for_status()

         log.debug(f"r.text: {r.text}")
@@ -495,7 +486,7 @@ async def delete_model(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
@@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/show",
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
     try:
-        r = requests.request(
-            method="POST",
-            url=f"{url}/api/show",
-            data=form_data.model_dump_json(exclude_none=True).encode(),
-        )
         r.raise_for_status()

         return r.json()
@@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
@@ -556,7 +547,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
@@ -573,12 +564,12 @@ async def generate_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/embeddings",
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
     try:
-        r = requests.request(
-            method="POST",
-            url=f"{url}/api/embeddings",
-            data=form_data.model_dump_json(exclude_none=True).encode(),
-        )
         r.raise_for_status()

         return r.json()
@@ -590,7 +581,7 @@ async def generate_embeddings(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
@@ -603,10 +594,9 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
-    log.info(f"generate_ollama_embeddings {form_data}")
-
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
@@ -623,12 +613,12 @@ def generate_ollama_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    r = requests.request(
+        method="POST",
+        url=f"{url}/api/embeddings",
+        data=form_data.model_dump_json(exclude_none=True).encode(),
+    )
     try:
-        r = requests.request(
-            method="POST",
-            url=f"{url}/api/embeddings",
-            data=form_data.model_dump_json(exclude_none=True).encode(),
-        )
         r.raise_for_status()

         data = r.json()
@@ -638,7 +628,7 @@ def generate_ollama_embeddings(
         if "embedding" in data:
             return data["embedding"]
         else:
-            raise "Something went wrong :/"
+            raise Exception("Something went wrong :/")
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
@@ -647,10 +637,10 @@ def generate_ollama_embeddings(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

-        raise error_detail
+        raise Exception(error_detail)


 class GenerateCompletionForm(BaseModel):
@@ -674,8 +664,7 @@ async def generate_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
@@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None


+def get_ollama_url(url_idx: Optional[int], model: str):
+    if url_idx is None:
+        if model not in app.state.MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
+            )
+        url_idx = random.choice(app.state.MODELS[model]["urls"])
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    return url
+
+
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
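`get_ollama_url` centralizes the resolution that both chat endpoints previously duplicated: an explicit `url_idx` wins, otherwise a random backend among those serving the model, and an unknown model becomes a 400. Illustrative behavior, with hypothetical app state:

```python
# hypothetical state, for illustration only
app.state.config.OLLAMA_BASE_URLS = ["http://a:11434", "http://b:11434"]
app.state.MODELS = {"llama3:latest": {"urls": [0, 1]}}

get_ollama_url(None, "llama3:latest")   # random pick: "http://a:11434" or "http://b:11434"
get_ollama_url(1, "llama3:latest")      # explicit index: "http://b:11434"
get_ollama_url(None, "missing:latest")  # raises HTTPException(status_code=400)
```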
@@ -720,12 +721,7 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()=}")

     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
@@ -740,185 +736,21 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
+        params = model_info.params.model_dump()

-        if model_info.params:
+        if params:
             if payload.get("options") is None:
                 payload["options"] = {}

-            if (
-                model_info.params.get("mirostat", None)
-                and payload["options"].get("mirostat") is None
-            ):
-                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
-
-            if (
-                model_info.params.get("mirostat_eta", None)
-                and payload["options"].get("mirostat_eta") is None
-            ):
-                payload["options"]["mirostat_eta"] = model_info.params.get(
-                    "mirostat_eta", None
-                )
-
-            if (
-                model_info.params.get("mirostat_tau", None)
-                and payload["options"].get("mirostat_tau") is None
-            ):
-                payload["options"]["mirostat_tau"] = model_info.params.get(
-                    "mirostat_tau", None
-                )
-
-            if (
-                model_info.params.get("num_ctx", None)
-                and payload["options"].get("num_ctx") is None
-            ):
-                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
-
-            if (
-                model_info.params.get("num_batch", None)
-                and payload["options"].get("num_batch") is None
-            ):
-                payload["options"]["num_batch"] = model_info.params.get(
-                    "num_batch", None
-                )
-
-            if (
-                model_info.params.get("num_keep", None)
-                and payload["options"].get("num_keep") is None
-            ):
-                payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
-
-            if (
-                model_info.params.get("repeat_last_n", None)
-                and payload["options"].get("repeat_last_n") is None
-            ):
-                payload["options"]["repeat_last_n"] = model_info.params.get(
-                    "repeat_last_n", None
-                )
-
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload["options"].get("frequency_penalty") is None
-            ):
-                payload["options"]["repeat_penalty"] = model_info.params.get(
-                    "frequency_penalty", None
-                )
-
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload["options"].get("temperature") is None
-            ):
-                payload["options"]["temperature"] = model_info.params.get(
-                    "temperature", None
-                )
-
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload["options"].get("seed") is None
-            ):
-                payload["options"]["seed"] = model_info.params.get("seed", None)
-
-            if (
-                model_info.params.get("stop", None)
-                and payload["options"].get("stop") is None
-            ):
-                payload["options"]["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-            if (
-                model_info.params.get("tfs_z", None)
-                and payload["options"].get("tfs_z") is None
-            ):
-                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
-
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload["options"].get("max_tokens") is None
-            ):
-                payload["options"]["num_predict"] = model_info.params.get(
-                    "max_tokens", None
-                )
-
-            if (
-                model_info.params.get("top_k", None)
-                and payload["options"].get("top_k") is None
-            ):
-                payload["options"]["top_k"] = model_info.params.get("top_k", None)
-
-            if (
-                model_info.params.get("top_p", None)
-                and payload["options"].get("top_p") is None
-            ):
-                payload["options"]["top_p"] = model_info.params.get("top_p", None)
-
-            if (
-                model_info.params.get("min_p", None)
-                and payload["options"].get("min_p") is None
-            ):
-                payload["options"]["min_p"] = model_info.params.get("min_p", None)
-
-            if (
-                model_info.params.get("use_mmap", None)
-                and payload["options"].get("use_mmap") is None
-            ):
-                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
-
-            if (
-                model_info.params.get("use_mlock", None)
-                and payload["options"].get("use_mlock") is None
-            ):
-                payload["options"]["use_mlock"] = model_info.params.get(
-                    "use_mlock", None
-                )
-
-            if (
-                model_info.params.get("num_thread", None)
-                and payload["options"].get("num_thread") is None
-            ):
-                payload["options"]["num_thread"] = model_info.params.get(
-                    "num_thread", None
-                )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
+            payload["options"] = apply_model_params_to_body_ollama(
+                params, payload["options"]
             )
+            payload = apply_model_system_prompt_to_body(params, payload, user)

-            if payload.get("messages"):
-                payload["messages"] = add_or_update_system_message(
-                    system, payload["messages"]
-                )
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"

-    if url_idx == None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(payload)
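The block of per-key `if` statements this hunk deletes collapses into two helper calls. One nuance: the old checks only filled an option the request had left unset, whereas `apply_model_params_to_body_ollama` (defined in backend/utils/misc.py at the end of this patch) writes every param that is configured. A worked example of the new flow, values illustrative:

```python
from utils.misc import (
    apply_model_params_to_body_ollama,
    apply_model_system_prompt_to_body,
)

params = {"temperature": 0.2, "max_tokens": 128, "system": "You are terse."}
payload = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "hi"}],
    "options": {},
}

payload["options"] = apply_model_params_to_body_ollama(params, payload["options"])
# payload["options"] == {"temperature": 0.2, "num_predict": 128}

payload = apply_model_system_prompt_to_body(params, payload, None)
# the configured "You are terse." ends up as the request's system message
```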
"error" in res: error_detail = f"Ollama: {res['error']}" - except: + except Exception: error_detail = f"Ollama: {e}" raise HTTPException( @@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url): path_components = parsed_url.path.split("/") # Extract the desired output - user_repo = "/".join(path_components[1:3]) model_file = path_components[-1] return model_file @@ -1190,7 +966,6 @@ async def download_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - allowed_hosts = ["https://huggingface.co/", "https://github.com/"] if not any(form_data.url.startswith(host) for host in allowed_hosts): @@ -1199,7 +974,7 @@ async def download_model( detail="Invalid file_url. Only URLs from allowed hosts are permitted.", ) - if url_idx == None: + if url_idx is None: url_idx = 0 url = app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -1222,7 +997,7 @@ def upload_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): - if url_idx == None: + if url_idx is None: url_idx = 0 ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 948ee7c46a..50de53a537 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -17,7 +17,10 @@ from utils.utils import ( get_verified_user, get_admin_user, ) -from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body +from utils.misc import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) from config import ( SRC_LOG_LEVELS, @@ -368,7 +371,7 @@ async def generate_chat_completion( payload["model"] = model_info.base_model_id params = model_info.params.model_dump() - payload = apply_model_params_to_body(params, payload) + payload = apply_model_params_to_body_openai(params, payload) payload = apply_model_system_prompt_to_body(params, payload, user) model = app.state.MODELS[payload.get("model")] diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 3c387842d8..dddf3fbb2a 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id from utils.misc import ( openai_chat_chunk_message_template, openai_chat_completion_message_template, - apply_model_params_to_body, + apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) @@ -291,7 +291,7 @@ async def generate_function_chat_completion(form_data, user): form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - form_data = apply_model_params_to_body(params, form_data) + form_data = apply_model_params_to_body_openai(params, form_data) form_data = apply_model_system_prompt_to_body(params, form_data, user) pipe_id = get_pipe_id(form_data) diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 25dd4dd5b6..9de19d3f69 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -2,7 +2,7 @@ from pathlib import Path import hashlib import re from datetime import timedelta -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Callable import uuid import time @@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di # inplace function: form_data is modified -def apply_model_params_to_body(params: dict, form_data: dict) -> dict: +def apply_model_params_to_body( + params: dict, form_data: dict, mappings: dict[str, Callable] +) -> dict: if not params: return form_data + for key, cast_func in mappings.items(): + if (value := 
diff --git a/backend/utils/misc.py b/backend/utils/misc.py
index 25dd4dd5b6..9de19d3f69 100644
--- a/backend/utils/misc.py
+++ b/backend/utils/misc.py
@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Callable
 import uuid
 import time
@@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di


 # inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
     if not params:
         return form_data

+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
     mappings = {
         "temperature": float,
         "top_p": int,
@@ -147,10 +158,39 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
         "seed": lambda x: x,
         "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
     }
+    return apply_model_params_to_body(params, form_data, mappings)

-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    opts = [
+        "temperature",
+        "top_p",
+        "seed",
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    form_data = apply_model_params_to_body(params, form_data, mappings)
+
+    name_differences = {
+        "max_tokens": "num_predict",
+        "frequency_penalty": "repeat_penalty",
+    }
+
+    for key, value in name_differences.items():
+        if (param := params.get(key, None)) is not None:
+            form_data[value] = param

     return form_data
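A quick sanity check of the Ollama variant defined above: keys listed in `opts` pass through the identity lambdas unchanged, and the two OpenAI-named parameters are renamed to their Ollama equivalents (illustrative values):

```python
from utils.misc import apply_model_params_to_body_ollama

options = apply_model_params_to_body_ollama(
    {"top_k": 40, "max_tokens": 64, "frequency_penalty": 1.1}, {}
)
assert options == {"top_k": 40, "num_predict": 64, "repeat_penalty": 1.1}
```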