diff --git a/CHANGELOG.md b/CHANGELOG.md
index c1c6596a67..d57ba400c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.111] - 2024-03-10
+
+### Added
+
+- 🛡️ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role.
+- 🔄 **Update All Models**: Added a convenient button to update all models at once.
+- 📄 **Toggle PDF OCR**: Users can now toggle the PDF OCR option for improved parsing performance.
+- 🎨 **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111.
+- 🛠️ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow.
+
+### Fixed
+
+- 🔒 **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094).
+- 🔧 **Misalignment Issue**: Corrected misalignment of the Edit and Delete icons when the chat title is empty (Issue #1104).
+- 🔄 **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105).
+- 📁 **File Handling Bug**: Addressed the File Not Found notification shown when dropping a conversation element (Issue #1098).
+- 🖱️ **Dragged File Styling**: Fixed the dragged-file overlay styling issue.
+
 ## [0.1.110] - 2024-03-06
 
 ### Added

diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index e902bea270..375ed3f121 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -179,20 +179,26 @@ def merge_models_lists(model_lists):
 
 
 async def get_all_models():
     print("get_all_models")
-    tasks = [
-        fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
-        for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
-    ]
-    responses = await asyncio.gather(*tasks)
-    responses = list(filter(lambda x: x is not None and "error" not in x, responses))
-    models = {
-        "data": merge_models_lists(
-            list(map(lambda response: response["data"], responses))
-        )
-    }
-    app.state.MODELS = {model["id"]: model for model in models["data"]}
-    return models
+    if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
+        models = {"data": []}
+    else:
+        tasks = [
+            fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
+            for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
+        ]
+        responses = await asyncio.gather(*tasks)
+        responses = list(
+            filter(lambda x: x is not None and "error" not in x, responses)
+        )
+        models = {
+            "data": merge_models_lists(
+                list(map(lambda response: response["data"], responses))
+            )
+        }
+        app.state.MODELS = {model["id"]: model for model in models["data"]}
+
+    return models
 
 
 @app.get("/models")

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 6781a9a147..b21724cc9c 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -77,6 +77,7 @@ from constants import ERROR_MESSAGES
 
 app = FastAPI()
 
+app.state.PDF_EXTRACT_IMAGES = False
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
@@ -184,12 +185,15 @@ async def update_embedding_model(
     }
 
 
-@app.get("/chunk")
-async def get_chunk_params(user=Depends(get_admin_user)):
+@app.get("/config")
+async def get_rag_config(user=Depends(get_admin_user)):
     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "chunk": {
+            "chunk_size": app.state.CHUNK_SIZE,
+            "chunk_overlap": app.state.CHUNK_OVERLAP,
+        },
     }
 
 
@@ -198,17 +202,24 @@ class ChunkParamUpdateForm(BaseModel):
     chunk_overlap: int
 
 
-@app.post("/chunk/update")
-async def update_chunk_params(
-    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
-):
-    app.state.CHUNK_SIZE = form_data.chunk_size
-    app.state.CHUNK_OVERLAP = form_data.chunk_overlap
+class ConfigUpdateForm(BaseModel):
+    pdf_extract_images: bool
+    chunk: ChunkParamUpdateForm
+
+
+@app.post("/config/update")
+async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
+    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
+    app.state.CHUNK_SIZE = form_data.chunk.chunk_size
+    app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "chunk": {
+            "chunk_size": app.state.CHUNK_SIZE,
+            "chunk_overlap": app.state.CHUNK_OVERLAP,
+        },
     }
 
 
@@ -364,7 +375,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
     ]
 
     if file_ext == "pdf":
-        loader = PyPDFLoader(file_path, extract_images=True)
+        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
     elif file_ext == "csv":
         loader = CSVLoader(file_path)
     elif file_ext == "rst":
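Taken together, the `rag/main.py` hunks fold the old `/chunk` read and update endpoints into a single `/config` pair whose payload nests the chunking parameters under a `chunk` object. A minimal sketch of exercising the new endpoints; the base URL, token value, and the `/rag/api/v1` mount point are assumptions for illustration, not something this diff shows:

```python
import requests

BASE = "http://localhost:8080/rag/api/v1"  # assumed mount point of the RAG sub-app
TOKEN = "<admin-jwt>"  # placeholder; both endpoints require get_admin_user
HEADERS = {"Authorization": f"Bearer {TOKEN}"}

# GET /config replaces GET /chunk and now also reports pdf_extract_images.
config = requests.get(f"{BASE}/config", headers=HEADERS).json()
print(config["pdf_extract_images"], config["chunk"]["chunk_size"])

# POST /config/update replaces POST /chunk/update; chunk params are nested.
updated = requests.post(
    f"{BASE}/config/update",
    headers=HEADERS,
    json={
        "pdf_extract_images": True,  # sample values, not defaults from this diff
        "chunk": {"chunk_size": 1500, "chunk_overlap": 100},
    },
).json()
```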
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index 91b07e0aa2..b2da7d90c3 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str):
     template = re.sub(r"\[query\]", query, template)
 
     return template
+
+
+def rag_messages(docs, messages, template, k, embedding_function):
+    print(docs)
+
+    last_user_message_idx = None
+    for i in range(len(messages) - 1, -1, -1):
+        if messages[i]["role"] == "user":
+            last_user_message_idx = i
+            break
+
+    user_message = messages[last_user_message_idx]
+
+    if isinstance(user_message["content"], list):
+        # Handle list content input
+        content_type = "list"
+        query = ""
+        for content_item in user_message["content"]:
+            if content_item["type"] == "text":
+                query = content_item["text"]
+                break
+    elif isinstance(user_message["content"], str):
+        # Handle text content input
+        content_type = "text"
+        query = user_message["content"]
+    else:
+        # Fallback in case the input does not match expected types
+        content_type = None
+        query = ""
+
+    relevant_contexts = []
+
+    for doc in docs:
+        context = None
+
+        try:
+            if doc["type"] == "collection":
+                context = query_collection(
+                    collection_names=doc["collection_names"],
+                    query=query,
+                    k=k,
+                    embedding_function=embedding_function,
+                )
+            else:
+                context = query_doc(
+                    collection_name=doc["collection_name"],
+                    query=query,
+                    k=k,
+                    embedding_function=embedding_function,
+                )
+        except Exception as e:
+            print(e)
+            context = None
+
+        relevant_contexts.append(context)
+
+    context_string = ""
+    for context in relevant_contexts:
+        if context:
+            context_string += " ".join(context["documents"][0]) + "\n"
+
+    ra_content = rag_template(
+        template=template,
+        context=context_string,
+        query=query,
+    )
+
+    if content_type == "list":
+        new_content = []
+        for content_item in user_message["content"]:
+            if content_item["type"] == "text":
+                # Update the text item's content with ra_content
+                new_content.append({"type": "text", "text": ra_content})
+            else:
+                # Keep other types of content as they are
+                new_content.append(content_item)
+        new_user_message = {**user_message, "content": new_content}
+    else:
+        new_user_message = {
+            **user_message,
+            "content": ra_content,
+        }
+
+    messages[last_user_message_idx] = new_user_message
+
+    return messages
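The new `rag_messages` helper centralizes what the middleware in `backend/main.py` previously inlined: locate the last user turn, run retrieval for each referenced doc, and splice the templated context back into that turn. Below is a standalone sketch of just the splicing step, with retrieval replaced by a canned context string so it runs without a vector store; the helper name and sample data are illustrative, and it assumes the `[context]`/`[query]` placeholder convention of `rag_template`:

```python
import re


def rag_template(template: str, context: str, query: str) -> str:
    # Same [context]/[query] substitution style as rag_template above.
    template = re.sub(r"\[context\]", context, template)
    template = re.sub(r"\[query\]", query, template)
    return template


def splice_rag_content(messages, context_string, template):
    # Find the most recent user turn, as rag_messages does.
    idx = max(i for i, m in enumerate(messages) if m["role"] == "user")
    query = messages[idx]["content"]  # the plain-string content case
    ra_content = rag_template(template, context_string, query)
    # Replace that turn's content with the templated context plus question.
    messages[idx] = {**messages[idx], "content": ra_content}
    return messages


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is chunk overlap for?"},
]
out = splice_rag_content(
    messages,
    "Overlap keeps adjacent chunks from cutting context mid-sentence.",
    "Use the following context: [context]\nQuestion: [query]",
)
print(out[1]["content"])
```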
diff --git a/backend/config.py b/backend/config.py
index 019e44e010..831371bb7e 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -209,10 +209,6 @@ OLLAMA_API_BASE_URL = os.environ.get(
 
 OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
 
-if ENV == "prod":
-    if OLLAMA_BASE_URL == "/ollama":
-        OLLAMA_BASE_URL = "http://host.docker.internal:11434"
-
 
 
 if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
@@ -221,6 +217,11 @@ if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
     else OLLAMA_API_BASE_URL
 )
 
+if ENV == "prod":
+    if OLLAMA_BASE_URL == "/ollama":
+        OLLAMA_BASE_URL = "http://host.docker.internal:11434"
+
+
 OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
 OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
 
@@ -234,8 +235,6 @@ OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
 
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
 OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
 
-if OPENAI_API_KEY == "":
-    OPENAI_API_KEY = "none"
 
 if OPENAI_API_BASE_URL == "":
     OPENAI_API_BASE_URL = "https://api.openai.com/v1"

diff --git a/backend/data/config.json b/backend/data/config.json
index 1b5971005b..d3ada59c91 100644
--- a/backend/data/config.json
+++ b/backend/data/config.json
@@ -1,4 +1,5 @@
 {
+	"version": "0.0.1",
 	"ui": {
 		"prompt_suggestions": [
 			{

diff --git a/backend/main.py b/backend/main.py
index c7523ec62c..2532271824 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -28,7 +28,7 @@ from typing import List
 
 from utils.utils import get_admin_user
 
-from apps.rag.utils import query_doc, query_collection, rag_template
+from apps.rag.utils import rag_messages
 
 from config import (
@@ -60,19 +60,6 @@ app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 origins = ["*"]
 
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-@app.on_event("startup")
-async def on_startup():
-    await litellm_app_startup()
-
 
 class RAGMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
@@ -91,98 +78,33 @@ class RAGMiddleware(BaseHTTPMiddleware):
 
             # Example: Add a new key-value pair or modify existing ones
             # data["modified"] = True  # Example modification
             if "docs" in data:
-                docs = data["docs"]
-                print(docs)
-                last_user_message_idx = None
-                for i in range(len(data["messages"]) - 1, -1, -1):
-                    if data["messages"][i]["role"] == "user":
-                        last_user_message_idx = i
-                        break
-
-                user_message = data["messages"][last_user_message_idx]
-
-                if isinstance(user_message["content"], list):
-                    # Handle list content input
-                    content_type = "list"
-                    query = ""
-                    for content_item in user_message["content"]:
-                        if content_item["type"] == "text":
-                            query = content_item["text"]
-                            break
-                elif isinstance(user_message["content"], str):
-                    # Handle text content input
-                    content_type = "text"
-                    query = user_message["content"]
-                else:
-                    # Fallback in case the input does not match expected types
-                    content_type = None
-                    query = ""
-
-                relevant_contexts = []
-
-                for doc in docs:
-                    context = None
-
-                    try:
-                        if doc["type"] == "collection":
-                            context = query_collection(
-                                collection_names=doc["collection_names"],
-                                query=query,
-                                k=rag_app.state.TOP_K,
-                                embedding_function=rag_app.state.sentence_transformer_ef,
-                            )
-                        else:
-                            context = query_doc(
-                                collection_name=doc["collection_name"],
-                                query=query,
-                                k=rag_app.state.TOP_K,
-                                embedding_function=rag_app.state.sentence_transformer_ef,
-                            )
-                    except Exception as e:
-                        print(e)
-                        context = None
-
-                    relevant_contexts.append(context)
-
-                context_string = ""
-                for context in relevant_contexts:
-                    if context:
-                        context_string += " ".join(context["documents"][0]) + "\n"
-
-                ra_content = rag_template(
-                    template=rag_app.state.RAG_TEMPLATE,
-                    context=context_string,
-                    query=query,
+                data = {**data}
+                data["messages"] = rag_messages(
+                    data["docs"],
+                    data["messages"],
+                    rag_app.state.RAG_TEMPLATE,
+                    rag_app.state.TOP_K,
+                    rag_app.state.sentence_transformer_ef,
                 )
-
-                if content_type == "list":
-                    new_content = []
-                    for content_item in user_message["content"]:
-                        if content_item["type"] == "text":
-                            # Update the text item's content with ra_content
-                            new_content.append({"type": "text", "text": ra_content})
-                        else:
-                            # Keep other types of content as they are
-                            new_content.append(content_item)
-                    new_user_message = {**user_message, "content": new_content}
-                else:
-                    new_user_message = {
-                        **user_message,
-                        "content": ra_content,
-                    }
-
-                data["messages"][last_user_message_idx] = new_user_message
                 del data["docs"]
 
                 print(data["messages"])
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
 
-            # Create a new request with the modified body
-            scope = request.scope
-            scope["body"] = modified_body_bytes
-            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+            # Replace the request body with the modified one
+            request._body = modified_body_bytes
+
+            # Set custom header to ensure content-length matches new body length
+            request.headers.__dict__["_list"] = [
+                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+                *[
+                    (k, v)
+                    for k, v in request.headers.raw
+                    if k.lower() != b"content-length"
+                ],
+            ]
 
         response = await call_next(request)
         return response
@@ -194,6 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
 app.add_middleware(RAGMiddleware)
 
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     start_time = int(time.time())
@@ -204,6 +135,11 @@ async def check_url(request: Request, call_next):
     return response
 
 
+@app.on_event("startup")
+async def on_startup():
+    await litellm_app_startup()
+
+
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)
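The `backend/main.py` rewrite does two things: the inlined retrieval logic collapses into one `rag_messages` call, and the modified body is now attached to the existing request with a recomputed `Content-Length` instead of rebuilding the request from the ASGI scope. The header fix matters because a stale `Content-Length` would make the downstream app read too few or too many bytes. A self-contained sketch of the same body-rewriting pattern follows; it leans on Starlette internals (`request._body`, the raw header list), so it is version-sensitive, and the echo route is purely illustrative:

```python
import json

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()


class BodyRewriteMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST":
            data = json.loads(await request.body() or b"{}")
            data["rewritten"] = True  # illustrative modification

            body = json.dumps(data).encode("utf-8")
            # Cache the new body on the request and rebuild the raw header
            # list so Content-Length matches the rewritten payload,
            # mirroring RAGMiddleware above.
            request._body = body
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(body)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]
        return await call_next(request)


app.add_middleware(BodyRewriteMiddleware)


@app.post("/echo")
async def echo(request: Request):
    return await request.json()
```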
diff --git a/package.json b/package.json
index 66a893ad3e..29ff83915b 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
 	"name": "open-webui",
-	"version": "0.1.110",
+	"version": "0.1.111",
 	"private": true,
 	"scripts": {
 		"dev": "vite dev --host",

diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 6dcfbbe7d8..668fe227be 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -1,9 +1,9 @@
 import { RAG_API_BASE_URL } from '$lib/constants';
 
-export const getChunkParams = async (token: string) => {
+export const getRAGConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/config`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -27,18 +27,27 @@ export const getChunkParams = async (token: string) => {
 	return res;
 };
 
-export const updateChunkParams = async (token: string, size: number, overlap: number) => {
+type ChunkConfigForm = {
+	chunk_size: number;
+	chunk_overlap: number;
+};
+
+type RAGConfigForm = {
+	pdf_extract_images: boolean;
+	chunk: ChunkConfigForm;
+};
+
+export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/config/update`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			chunk_size: size,
-			chunk_overlap: overlap
+			...payload
 		})
 	})
 		.then(async (res) => {
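On the wire, the renamed `updateRAGConfig` posts exactly the nested shape that `ConfigUpdateForm` in `backend/apps/rag/main.py` declares. A quick compatibility check using those Pydantic models verbatim; the sample values are illustrative, not defaults from this diff:

```python
from pydantic import BaseModel


class ChunkParamUpdateForm(BaseModel):
    chunk_size: int
    chunk_overlap: int


class ConfigUpdateForm(BaseModel):
    pdf_extract_images: bool
    chunk: ChunkParamUpdateForm


# JSON body as produced by updateRAGConfig(token, payload) on the frontend.
payload = {
    "pdf_extract_images": False,
    "chunk": {"chunk_size": 1500, "chunk_overlap": 100},
}

form = ConfigUpdateForm(**payload)
assert form.chunk.chunk_overlap == 100
```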
diff --git a/src/lib/components/chat/Settings/Models.svelte b/src/lib/components/chat/Settings/Models.svelte
index 92e4afd240..6ce14f9e7b 100644
--- a/src/lib/components/chat/Settings/Models.svelte
+++ b/src/lib/components/chat/Settings/Models.svelte
@@ -14,6 +14,7 @@
 	import { splitStream } from '$lib/utils';
 	import { onMount, getContext } from 'svelte';
 
 	import { addLiteLLMModel, deleteLiteLLMModel, getLiteLLMModelInfo } from '$lib/apis/litellm';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
 
 	const i18n = getContext('i18n');
@@ -39,6 +40,10 @@
 	let OLLAMA_URLS = [];
 	let selectedOllamaUrlIdx: string | null = null;
+
+	let updateModelId = null;
+	let updateProgress = null;
+
 	let showExperimentalOllama = false;
 	let ollamaVersion = '';
 	const MAX_PARALLEL_DOWNLOADS = 3;
 
@@ -63,6 +68,71 @@
 
 	let deleteModelTag = '';
 
+	const updateModelsHandler = async () => {
+		for (const model of $models.filter(
+			(m) =>
+				m.size != null &&
+				(selectedOllamaUrlIdx === null ? true : (m?.urls ?? []).includes(selectedOllamaUrlIdx))
+		)) {
+			console.log(model);
+
+			updateModelId = model.id;
+			const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
+				(error) => {
+					toast.error(error);
+					return null;
+				}
+			);
+
+			if (res) {
+				const reader = res.body
+					.pipeThrough(new TextDecoderStream())
+					.pipeThrough(splitStream('\n'))
+					.getReader();
+
+				while (true) {
+					try {
+						const { value, done } = await reader.read();
+						if (done) break;
+
+						let lines = value.split('\n');
+
+						for (const line of lines) {
+							if (line !== '') {
+								let data = JSON.parse(line);
+
+								console.log(data);
+								if (data.error) {
+									throw data.error;
+								}
+								if (data.detail) {
+									throw data.detail;
+								}
+								if (data.status) {
+									if (data.digest) {
+										updateProgress = 0;
+										if (data.completed) {
+											updateProgress = Math.round((data.completed / data.total) * 1000) / 10;
+										} else {
+											updateProgress = 100;
+										}
+									} else {
+										toast.success(data.status);
+									}
+								}
+							}
+						}
+					} catch (error) {
+						console.log(error);
+					}
+				}
+			}
+		}
+
+		updateModelId = null;
+		updateProgress = null;
+	};
+
 	const pullModelHandler = async () => {
 		const sanitizedModelTag = modelTag.trim();
 		if (modelDownloadStatus[sanitizedModelTag]) {
@@ -389,7 +459,7 @@
 			return [];
 		});
 
-		if (OLLAMA_URLS.length > 1) {
+		if (OLLAMA_URLS.length > 0) {
 			selectedOllamaUrlIdx = 0;
 		}
 
@@ -404,18 +474,51 @@
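The new `updateModelsHandler` re-pulls each installed model and reduces every streamed NDJSON status object to a one-decimal percentage via `Math.round((completed / total) * 1000) / 10`. The same reduction in Python, fed a canned stream instead of a live pull response; the sample lines only illustrate the `status`, `digest`, `completed`, and `total` fields the handler reads:

```python
import json

# Canned NDJSON lines shaped like the streamed pull response the handler
# consumes; real lines arrive over HTTP, one JSON object per line.
stream = [
    '{"status": "pulling manifest"}',
    '{"status": "downloading", "digest": "sha256:abc", "total": 2000, "completed": 500}',
    '{"status": "downloading", "digest": "sha256:abc", "total": 2000, "completed": 2000}',
    '{"status": "success"}',
]

for line in stream:
    data = json.loads(line)
    if data.get("error") or data.get("detail"):
        raise RuntimeError(data.get("error") or data.get("detail"))
    if data.get("status"):
        if data.get("digest"):
            if data.get("completed"):
                # Mirrors Math.round((completed / total) * 1000) / 10.
                progress = round(data["completed"] / data["total"] * 1000) / 10
            else:
                progress = 100  # same fallback the handler uses
            print(f"{progress}%")
        else:
            print(data["status"])
```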