diff --git a/CHANGELOG.md b/CHANGELOG.md
index a7913785d3..86455f0504 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,34 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.3.9] - 2024-07-17
+
+### Added
+
+- **📁 Files Chat Controls**: We've reverted to the previous file-handling behavior, where uploaded files are always included. You can now manage files directly within the chat controls section, including removing files you no longer need.
+- **🔧 "Action" Function Support**: Introducing a new "Action" function type for adding custom buttons to the message toolbar, enabling more interactive messaging. Documentation is coming soon; a hedged sketch of the expected interface follows this section.
+- **📜 Citations Handling**: Citations for newly uploaded files in the documents workspace now display the actual filename. You can also click a filename to open the file in a new tab.
+- **🛠️ Event Emitter and Call Updates**: Enhanced 'event_emitter' to allow message replacement and 'event_call' to support text input for Tools and Functions. Detailed documentation will be provided shortly.
+- **🎨 Styling Refactor**: Various styling updates for a cleaner, more cohesive user interface.
+- **🌐 Enhanced Translations**: Improved translations for Catalan, Ukrainian, and Brazilian Portuguese.
+
+### Fixed
+
+- **🔧 Chat Controls Priority**: Resolved an issue where Chat Controls values were overridden by model parameters. The precedence is now Chat Controls, then Global Settings, then Model Settings.
+- **🪲 Debug Logs**: Fixed an issue where debug logs were not being written.
+- **🔑 Automatic1111 Auth Key**: The auth key for Automatic1111 is no longer required.
+- **📝 Title Generation**: Title generation now runs only once, even when multiple models are in a chat.
+- **✅ Boolean Values in Params**: Added support for boolean values in parameters.
+- **🖼️ Files Overlay Styling**: Fixed a styling issue with the files overlay.
+
+### Changed
+
+- **⬆️ Dependency Updates**
+  - Upgraded 'pydantic' from 2.7.1 to 2.8.2.
+  - Upgraded 'sqlalchemy' from 2.0.30 to 2.0.31.
+  - Upgraded 'unstructured' from 0.14.9 to 0.14.10.
+  - Upgraded 'chromadb' from 0.5.3 to 0.5.4.
+
 ## [0.3.8] - 2024-07-09
 
 ### Added
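To make the new "Action" entry concrete before the formal docs land, here is a minimal, hypothetical sketch of an Action function. It is inferred from the loader change in `backend/apps/webui/utils.py` and the new `/api/chat/actions/{action_id}` endpoint later in this patch; the event payloads mirror the handlers added to `Chat.svelte`, but the example logic and field defaults are assumptions, not a documented API.

```python
# Hypothetical sketch of an "Action" function, based on how this patch loads
# (`module.Action()`) and invokes (`function_module.action(**params)`) actions.
from pydantic import BaseModel


class Action:
    class Valves(BaseModel):
        # Admin-configurable settings; the endpoint populates these from
        # Functions.get_function_valves_by_id(). The field is an assumption.
        enabled: bool = True

    def __init__(self):
        self.valves = self.Valves()

    async def action(self, body: dict, __user__=None, __event_emitter__=None, __event_call__=None):
        # "input" events open the new text-input dialog added in Chat.svelte
        # and resolve with whatever the user typed.
        note = await __event_call__(
            {
                "type": "input",
                "data": {
                    "title": "Add note",
                    "message": "Write a note to append to this message",
                    "placeholder": "Your note...",
                },
            }
        )

        # "replace" events overwrite the message content on the client.
        await __event_emitter__(
            {"type": "replace", "data": {"content": f"{body['messages'][-1]['content']}\n\n> {note}"}}
        )
```

The endpoint only passes `__event_emitter__`, `__event_call__`, `__user__`, `__model__`, and `__id__` when they appear in the method's signature, so an action opts into exactly the hooks it needs.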
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index 24542ee930..9ae0ad67b7 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -421,7 +421,7 @@ def save_url_image(url):
 
 
 @app.post("/generations")
-def generate_image(
+async def image_generations(
     form_data: GenerateImageForm,
     user=Depends(get_verified_user),
 ):
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index fd4ba7b060..0a36d4c2be 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -728,8 +728,10 @@ async def generate_chat_completion(
     )
 
     payload = {
-        **form_data.model_dump(exclude_none=True),
+        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
@@ -741,52 +743,85 @@ async def generate_chat_completion(
             model_info.params = model_info.params.model_dump()
 
         if model_info.params:
-            payload["options"] = {}
+            if payload.get("options") is None:
+                payload["options"] = {}
 
-            if model_info.params.get("mirostat", None):
+            if (
+                model_info.params.get("mirostat", None)
+                and payload["options"].get("mirostat") is None
+            ):
                 payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
 
-            if model_info.params.get("mirostat_eta", None):
+            if (
+                model_info.params.get("mirostat_eta", None)
+                and payload["options"].get("mirostat_eta") is None
+            ):
                 payload["options"]["mirostat_eta"] = model_info.params.get(
                     "mirostat_eta", None
                 )
 
-            if model_info.params.get("mirostat_tau", None):
-
+            if (
+                model_info.params.get("mirostat_tau", None)
+                and payload["options"].get("mirostat_tau") is None
+            ):
                 payload["options"]["mirostat_tau"] = model_info.params.get(
                     "mirostat_tau", None
                 )
 
-            if model_info.params.get("num_ctx", None):
+            if (
+                model_info.params.get("num_ctx", None)
+                and payload["options"].get("num_ctx") is None
+            ):
                 payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
 
-            if model_info.params.get("num_batch", None):
+            if (
+                model_info.params.get("num_batch", None)
+                and payload["options"].get("num_batch") is None
+            ):
                 payload["options"]["num_batch"] = model_info.params.get(
                     "num_batch", None
                 )
 
-            if model_info.params.get("num_keep", None):
+            if (
+                model_info.params.get("num_keep", None)
+                and payload["options"].get("num_keep") is None
+            ):
                 payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
 
-            if model_info.params.get("repeat_last_n", None):
+            if (
+                model_info.params.get("repeat_last_n", None)
+                and payload["options"].get("repeat_last_n") is None
+            ):
                 payload["options"]["repeat_last_n"] = model_info.params.get(
                     "repeat_last_n", None
                 )
 
-            if model_info.params.get("frequency_penalty", None):
+            if (
+                model_info.params.get("frequency_penalty", None)
+                and payload["options"].get("frequency_penalty") is None
+            ):
                 payload["options"]["repeat_penalty"] = model_info.params.get(
                     "frequency_penalty", None
                 )
 
-            if model_info.params.get("temperature", None) is not None:
+            if (
+                model_info.params.get("temperature", None)
+                and payload["options"].get("temperature") is None
+            ):
                 payload["options"]["temperature"] = model_info.params.get(
                     "temperature", None
                 )
 
-            if model_info.params.get("seed", None):
+            if (
+                model_info.params.get("seed", None)
+                and payload["options"].get("seed") is None
+            ):
                 payload["options"]["seed"] = model_info.params.get("seed", None)
payload["options"].get("stop") is None + ): payload["options"]["stop"] = ( [ bytes(stop, "utf-8").decode("unicode_escape") @@ -796,37 +831,56 @@ async def generate_chat_completion( else None ) - if model_info.params.get("tfs_z", None): + if ( + model_info.params.get("tfs_z", None) + and payload["options"].get("tfs_z") is None + ): payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None) - if model_info.params.get("max_tokens", None): + if ( + model_info.params.get("max_tokens", None) + and payload["options"].get("max_tokens") is None + ): payload["options"]["num_predict"] = model_info.params.get( "max_tokens", None ) - if model_info.params.get("top_k", None): + if ( + model_info.params.get("top_k", None) + and payload["options"].get("top_k") is None + ): payload["options"]["top_k"] = model_info.params.get("top_k", None) - if model_info.params.get("top_p", None): + if ( + model_info.params.get("top_p", None) + and payload["options"].get("top_p") is None + ): payload["options"]["top_p"] = model_info.params.get("top_p", None) - if model_info.params.get("use_mmap", None): + if ( + model_info.params.get("use_mmap", None) + and payload["options"].get("use_mmap") is None + ): payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None) - if model_info.params.get("use_mlock", None): + if ( + model_info.params.get("use_mlock", None) + and payload["options"].get("use_mlock") is None + ): payload["options"]["use_mlock"] = model_info.params.get( "use_mlock", None ) - if model_info.params.get("num_thread", None): + if ( + model_info.params.get("num_thread", None) + and payload["options"].get("num_thread") is None + ): payload["options"]["num_thread"] = model_info.params.get( "num_thread", None ) system = model_info.params.get("system", None) if system: - # Check if the payload already has a system message - # If not, add a system message to the payload system = prompt_template( system, **( @@ -893,10 +947,10 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): form_data = OpenAIChatCompletionForm(**form_data) + payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} - payload = { - **form_data.model_dump(exclude_none=True), - } + if "metadata" in payload: + del payload["metadata"] model_id = form_data.model model_info = Models.get_model_by_id(model_id) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 7c67c40ae9..6c2906095a 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -21,6 +21,7 @@ from utils.utils import ( get_admin_user, ) from utils.task import prompt_template +from utils.misc import add_or_update_system_message from config import ( SRC_LOG_LEVELS, @@ -357,6 +358,8 @@ async def generate_chat_completion( ): idx = 0 payload = {**form_data} + if "metadata" in payload: + del payload["metadata"] model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) @@ -368,24 +371,33 @@ async def generate_chat_completion( model_info.params = model_info.params.model_dump() if model_info.params: - if model_info.params.get("temperature", None) is not None: + if ( + model_info.params.get("temperature", None) + and payload.get("temperature") is None + ): payload["temperature"] = float(model_info.params.get("temperature")) - if model_info.params.get("top_p", None): + if model_info.params.get("top_p", None) and payload.get("top_p") is None: payload["top_p"] = int(model_info.params.get("top_p", None)) - if model_info.params.get("max_tokens", None): + if ( + 
model_info.params.get("max_tokens", None) + and payload.get("max_tokens") is None + ): payload["max_tokens"] = int(model_info.params.get("max_tokens", None)) - if model_info.params.get("frequency_penalty", None): + if ( + model_info.params.get("frequency_penalty", None) + and payload.get("frequency_penalty") is None + ): payload["frequency_penalty"] = int( model_info.params.get("frequency_penalty", None) ) - if model_info.params.get("seed", None): + if model_info.params.get("seed", None) and payload.get("seed") is None: payload["seed"] = model_info.params.get("seed", None) - if model_info.params.get("stop", None): + if model_info.params.get("stop", None) and payload.get("stop") is None: payload["stop"] = ( [ bytes(stop, "utf-8").decode("unicode_escape") @@ -410,21 +422,10 @@ async def generate_chat_completion( else {} ), ) - # Check if the payload already has a system message - # If not, add a system message to the payload if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) + payload["messages"] = add_or_update_system_message( + system, payload["messages"] + ) else: pass diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index c0f8a09ed2..8631846ec7 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -930,7 +930,9 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): ) -def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: +def store_data_in_vector_db( + data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False +) -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.config.CHUNK_SIZE, @@ -942,7 +944,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b if len(docs) > 0: log.info(f"store_data_in_vector_db {docs}") - return store_docs_in_vector_db(docs, collection_name, overwrite), None + return store_docs_in_vector_db(docs, collection_name, metadata, overwrite), None else: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) @@ -956,14 +958,16 @@ def store_text_in_vector_db( add_start_index=True, ) docs = text_splitter.create_documents([text], metadatas=[metadata]) - return store_docs_in_vector_db(docs, collection_name, overwrite) + return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite) -def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: +def store_docs_in_vector_db( + docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False +) -> bool: log.info(f"store_docs_in_vector_db {docs} {collection_name}") texts = [doc.page_content for doc in docs] - metadatas = [doc.metadata for doc in docs] + metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] # ChromaDB does not like datetime formats # for meta-data so convert them to string. 
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index c0f8a09ed2..8631846ec7 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -930,7 +930,9 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
         )
 
 
-def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
+def store_data_in_vector_db(
+    data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=app.state.config.CHUNK_SIZE,
@@ -942,7 +944,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
 
     if len(docs) > 0:
         log.info(f"store_data_in_vector_db {docs}")
-        return store_docs_in_vector_db(docs, collection_name, overwrite), None
+        return store_docs_in_vector_db(docs, collection_name, metadata, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -956,14 +958,16 @@ def store_text_in_vector_db(
         add_start_index=True,
     )
     docs = text_splitter.create_documents([text], metadatas=[metadata])
-    return store_docs_in_vector_db(docs, collection_name, overwrite)
+    return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite)
 
 
-def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+def store_docs_in_vector_db(
+    docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+) -> bool:
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
-    metadatas = [doc.metadata for doc in docs]
+    metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
 
     # ChromaDB does not like datetime formats
     # for meta-data so convert them to string.
@@ -1237,13 +1241,21 @@ def process_doc(
             data = loader.load()
 
             try:
-                result = store_data_in_vector_db(data, collection_name)
+                result = store_data_in_vector_db(
+                    data,
+                    collection_name,
+                    {
+                        "file_id": form_data.file_id,
+                        "name": file.meta.get("name", file.filename),
+                    },
+                )
 
                 if result:
                     return {
                         "status": True,
                         "collection_name": collection_name,
                         "known_type": known_type,
+                        "filename": file.meta.get("name", file.filename),
                     }
             except Exception as e:
                 raise HTTPException(
diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py
index 123ff31cd2..18ce7a6072 100644
--- a/backend/apps/socket/main.py
+++ b/backend/apps/socket/main.py
@@ -137,3 +137,34 @@ async def disconnect(sid):
         await sio.emit("user-count", {"count": len(USER_POOL)})
     else:
         print(f"Unknown session ID {sid} disconnected")
+
+
+async def get_event_emitter(request_info):
+    async def __event_emitter__(event_data):
+        await sio.emit(
+            "chat-events",
+            {
+                "chat_id": request_info["chat_id"],
+                "message_id": request_info["message_id"],
+                "data": event_data,
+            },
+            to=request_info["session_id"],
+        )
+
+    return __event_emitter__
+
+
+async def get_event_call(request_info):
+    async def __event_call__(event_data):
+        response = await sio.call(
+            "chat-events",
+            {
+                "chat_id": request_info["chat_id"],
+                "message_id": request_info["message_id"],
+                "data": event_data,
+            },
+            to=request_info["session_id"],
+        )
+        return response
+
+    return __event_call__
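These two factories centralize what used to be ad-hoc closures in `main.py`, and they are what Tools, Functions, and the new Action endpoint receive as `__event_emitter__` and `__event_call__`. A hedged usage sketch — the event types match the handlers added to `Chat.svelte` later in this patch, but the surrounding function is invented for illustration:

```python
async def example_tool(__event_emitter__, __event_call__):
    # Stream extra text into the current response message ("message" appends).
    await __event_emitter__({"type": "message", "data": {"content": "Working on it...\n"}})

    # Ask the user to confirm before continuing; sio.call() blocks until the
    # client answers and returns its reply.
    ok = await __event_call__(
        {
            "type": "confirmation",
            "data": {"title": "Continue?", "message": "This may take a while."},
        }
    )
    if not ok:
        # "replace" overwrites the message content entirely.
        await __event_emitter__({"type": "replace", "data": {"content": "Cancelled."}})
```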
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index ab28868ae4..570cad9f19 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -20,7 +20,6 @@ from apps.webui.routers import (
 )
 from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
-
 from apps.webui.utils import load_function_module_by_id
 
 from utils.misc import stream_message_template
@@ -48,12 +47,14 @@ from config import (
     OAUTH_PICTURE_CLAIM,
 )
 
+from apps.socket.main import get_event_call, get_event_emitter
+
 import inspect
 import uuid
 import time
 import json
 
-from typing import Iterator, Generator
+from typing import Iterator, Generator, Optional
 from pydantic import BaseModel
 
 app = FastAPI()
@@ -164,6 +165,10 @@ async def get_pipe_models():
                         f"{function_module.name}{manifold_pipe_name}"
                     )
 
+                pipe_flag = {"type": pipe.type}
+                if hasattr(function_module, "ChatValves"):
+                    pipe_flag["valves_spec"] = function_module.ChatValves.schema()
+
                 pipe_models.append(
                     {
                         "id": manifold_pipe_id,
@@ -171,10 +176,14 @@ async def get_pipe_models():
                         "object": "model",
                         "created": pipe.created_at,
                         "owned_by": "openai",
-                        "pipe": {"type": pipe.type},
+                        "pipe": pipe_flag,
                     }
                 )
         else:
+            pipe_flag = {"type": "pipe"}
+            if hasattr(function_module, "ChatValves"):
+                pipe_flag["valves_spec"] = function_module.ChatValves.schema()
+
             pipe_models.append(
                 {
                     "id": pipe.id,
@@ -182,7 +191,7 @@ async def get_pipe_models():
                     "object": "model",
                     "created": pipe.created_at,
                     "owned_by": "openai",
-                    "pipe": {"type": "pipe"},
+                    "pipe": pipe_flag,
                 }
             )
 
@@ -193,6 +202,27 @@ async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
 
+    metadata = None
+    if "metadata" in form_data:
+        metadata = form_data["metadata"]
+        del form_data["metadata"]
+
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+
+    if metadata:
+        if (
+            metadata.get("session_id")
+            and metadata.get("chat_id")
+            and metadata.get("message_id")
+        ):
+            __event_emitter__ = await get_event_emitter(metadata)
+            __event_call__ = await get_event_call(metadata)
+
+        if metadata.get("task"):
+            __task__ = metadata.get("task")
+
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
@@ -307,6 +337,15 @@ async def generate_function_chat_completion(form_data, user):
 
         params = {**params, "__user__": __user__}
 
+    if "__event_emitter__" in sig.parameters:
+        params = {**params, "__event_emitter__": __event_emitter__}
+
+    if "__event_call__" in sig.parameters:
+        params = {**params, "__event_call__": __event_call__}
+
+    if "__task__" in sig.parameters:
+        params = {**params, "__task__": __task__}
+
     if form_data["stream"]:
 
         async def stream_content():
diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py
index 907576b80e..cb73da6944 100644
--- a/backend/apps/webui/models/functions.py
+++ b/backend/apps/webui/models/functions.py
@@ -167,6 +167,15 @@ class FunctionsTable:
             .all()
         ]
 
+    def get_global_action_functions(self) -> List[FunctionModel]:
+        with get_db() as db:
+            return [
+                FunctionModel.model_validate(function)
+                for function in db.query(Function)
+                .filter_by(type="action", is_active=True, is_global=True)
+                .all()
+            ]
+
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         with get_db() as db:
diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py
index fffe0743c1..99fb923a12 100644
--- a/backend/apps/webui/routers/files.py
+++ b/backend/apps/webui/routers/files.py
@@ -58,6 +58,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
 
     # replace filename with uuid
     id = str(uuid.uuid4())
+    name = filename
     filename = f"{id}_{filename}"
     file_path = f"{UPLOAD_DIR}/{filename}"
 
@@ -73,6 +74,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
             "id": id,
             "filename": filename,
             "meta": {
+                "name": name,
                 "content_type": file.content_type,
                 "size": len(contents),
                 "path": file_path,
diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py
index 545120835c..96d2b29ebf 100644
--- a/backend/apps/webui/utils.py
+++ b/backend/apps/webui/utils.py
@@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
             return module.Pipe(), "pipe", frontmatter
         elif hasattr(module, "Filter"):
             return module.Filter(), "filter", frontmatter
+        elif hasattr(module, "Action"):
+            return module.Action(), "action", frontmatter
         else:
             raise Exception("No Function class found")
     except Exception as e:
diff --git a/backend/constants.py b/backend/constants.py
index 7c366c2224..b9c7fc430d 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -95,8 +95,8 @@ class TASKS(str, Enum):
     def __str__(self) -> str:
         return super().__str__()
 
-    DEFAULT = lambda task="": f"{task if task else 'default'}"
-    TITLE_GENERATION = "Title Generation"
-    EMOJI_GENERATION = "Emoji Generation"
-    QUERY_GENERATION = "Query Generation"
-    FUNCTION_CALLING = "Function Calling"
+    DEFAULT = lambda task="": f"{task if task else 'generation'}"
+    TITLE_GENERATION = "title_generation"
+    EMOJI_GENERATION = "emoji_generation"
+    QUERY_GENERATION = "query_generation"
+    FUNCTION_CALLING = "function_calling"
diff --git a/backend/data/config.json b/backend/data/config.json
index 6c0ad2b9fb..7c7acde917 100644
--- a/backend/data/config.json
+++ b/backend/data/config.json
@@ -1,7 +1,7 @@
 {
 	"version": 0,
 	"ui": {
-		"default_locale": "en-US",
+		"default_locale": "",
 		"prompt_suggestions": [
 			{
 				"title": ["Help me study", "vocabulary for a college entrance exam"],
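Note how `generate_function_chat_completion` above injects `__event_emitter__`, `__event_call__`, and `__task__` only when a pipe's signature declares them. A hypothetical pipe opting into two of the hooks (the parameter names come from this patch; the body is invented):

```python
class Pipe:
    async def pipe(self, body: dict, __event_emitter__=None, __task__=None) -> str:
        # Background tasks (e.g. "title_generation") also flow through pipes;
        # skip UI events for them so title runs don't touch the chat view.
        if __event_emitter__ and __task__ is None:
            await __event_emitter__(
                {"type": "message", "data": {"content": "Pipe started...\n"}}
            )
        return f"Echo: {body['messages'][-1]['content']}"
```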
diff --git a/backend/main.py b/backend/main.py
index 89252e1641..62f07a868c 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -29,7 +29,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse
 
 
-from apps.socket.main import sio, app as socket_app
+from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -317,7 +317,7 @@ async def get_function_call_response(
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": TASKS.FUNCTION_CALLING,
+        "task": str(TASKS.FUNCTION_CALLING),
     }
 
     try:
@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     content={"detail": str(e)},
                 )
 
+        # Extract valves from the request body
+        valves = None
+        if "valves" in body:
+            valves = body["valves"]
+            del body["valves"]
+
         # Extract session_id, chat_id and message_id from the request body
         session_id = None
         if "session_id" in body:
@@ -632,24 +638,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             message_id = body["id"]
             del body["id"]
 
-        async def __event_emitter__(data):
-            await sio.emit(
-                "chat-events",
-                {
-                    "chat_id": chat_id,
-                    "message_id": message_id,
-                    "data": data,
-                },
-                to=session_id,
-            )
-
-        async def __event_call__(data):
-            response = await sio.call(
-                "chat-events",
-                {"chat_id": chat_id, "message_id": message_id, "data": data},
-                to=session_id,
-            )
-            return response
+        __event_emitter__ = await get_event_emitter(
+            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
+        )
+        __event_call__ = await get_event_call(
+            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
+        )
 
         # Initialize data_items to store additional data to be sent to the client
         data_items = []
@@ -703,6 +697,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if len(citations) > 0:
             data_items.append({"citations": citations})
 
+        body["metadata"] = {
+            "session_id": session_id,
+            "chat_id": chat_id,
+            "message_id": message_id,
+            "valves": valves,
+        }
+
         modified_body_bytes = json.dumps(body).encode("utf-8")
         # Replace the request body with the modified one
         request._body = modified_body_bytes
@@ -823,9 +824,6 @@ def filter_pipeline(payload, user):
         if "detail" in res:
             raise Exception(r.status_code, res["detail"])
 
-    if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
-        del payload["task"]
-
     return payload
@@ -935,6 +933,7 @@
 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 async def get_all_models():
+    # TODO: Optimize this function
     pipe_models = []
     openai_models = []
     ollama_models = []
@@ -961,6 +960,14 @@ async def get_all_models():
 
     models = pipe_models + openai_models + ollama_models
 
+    global_action_ids = [
+        function.id for function in Functions.get_global_action_functions()
+    ]
+    enabled_action_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("action", active_only=True)
+    ]
+
     custom_models = Models.get_all_models()
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
@@ -971,9 +978,33 @@ async def get_all_models():
                 ):
                     model["name"] = custom_model.name
                     model["info"] = custom_model.model_dump()
+
+                    action_ids = [] + global_action_ids
+                    if "info" in model and "meta" in model["info"]:
+                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
+                    action_ids = list(set(action_ids))
+                    action_ids = [
+                        action_id
+                        for action_id in action_ids
+                        if action_id in enabled_action_ids
+                    ]
+
+                    model["actions"] = []
+                    for action_id in action_ids:
+                        action = Functions.get_function_by_id(action_id)
+                        model["actions"].append(
+                            {
+                                "id": action_id,
+                                "name": action.name,
+                                "description": action.meta.description,
+                                "icon_url": action.meta.manifest.get("icon_url", None),
+                            }
+                        )
+
         else:
             owned_by = "openai"
             pipe = None
+            actions = []
 
             for model in models:
                 if (
@@ -983,6 +1014,27 @@ async def get_all_models():
                     owned_by = model["owned_by"]
                     if "pipe" in model:
                         pipe = model["pipe"]
+
+                    action_ids = [] + global_action_ids
+                    if "info" in model and "meta" in model["info"]:
+                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
+                    action_ids = list(set(action_ids))
+                    action_ids = [
+                        action_id
+                        for action_id in action_ids
+                        if action_id in enabled_action_ids
+                    ]
+
+                    actions = [
+                        {
+                            "id": action_id,
+                            "name": Functions.get_function_by_id(action_id).name,
+                            "description": Functions.get_function_by_id(
+                                action_id
+                            ).meta.description,
+                        }
+                        for action_id in action_ids
+                    ]
                     break
 
             models.append(
@@ -995,6 +1047,7 @@ async def get_all_models():
                     "info": custom_model.model_dump(),
                     "preset": True,
                     **({"pipe": pipe} if pipe is not None else {}),
+                    "actions": actions,
                 }
             )
 
@@ -1036,13 +1089,24 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
-
     model = app.state.MODELS[model_id]
 
-    pipe = model.get("pipe")
-    if pipe:
+    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
+    task = None
+    if "task" in form_data:
+        task = form_data["task"]
+        del form_data["task"]
+
+    if task:
+        if "metadata" in form_data:
+            form_data["metadata"]["task"] = task
+        else:
+            form_data["metadata"] = {"task": task}
+
+    if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
+        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1107,24 +1171,21 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         else:
             pass
 
-    async def __event_emitter__(event_data):
-        await sio.emit(
-            "chat-events",
-            {
-                "chat_id": data["chat_id"],
-                "message_id": data["id"],
-                "data": event_data,
-            },
-            to=data["session_id"],
-        )
+    __event_emitter__ = await get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
 
-    async def __event_call__(event_data):
-        response = await sio.call(
-            "chat-events",
-            {"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data},
-            to=data["session_id"],
-        )
-        return response
+    __event_call__ = await get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
 
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
@@ -1222,6 +1283,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     return data
 
 
+@app.post("/api/chat/actions/{action_id}")
+async def chat_action(
+    action_id: str, form_data: dict, user=Depends(get_verified_user)
+):
+    action = Functions.get_function_by_id(action_id)
+    if not action:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Action not found",
+        )
+
+    data = form_data
+    model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
+
+    __event_emitter__ = await get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+    __event_call__ = await get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    if action_id in webui_app.state.FUNCTIONS:
+        function_module = webui_app.state.FUNCTIONS[action_id]
+    else:
+        function_module, _, _ = load_function_module_by_id(action_id)
+        webui_app.state.FUNCTIONS[action_id] = function_module
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(action_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+
+    if hasattr(function_module, "action"):
+        try:
+            action = function_module.action
+
+            # Get the signature of the function
+            sig = inspect.signature(action)
+            params = {"body": data}
+
+            # Extra parameters to be passed to the function
+            extra_params = {
+                "__model__": model,
+                "__id__": action_id,
+                "__event_emitter__": __event_emitter__,
+                "__event_call__": __event_call__,
+            }
+
+            # Add extra params contained in the function signature
+            for key, value in extra_params.items():
+                if key in sig.parameters:
+                    params[key] = value
+
+            if "__user__" in sig.parameters:
+                __user__ = {
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                }
+
+                try:
+                    if hasattr(function_module, "UserValves"):
+                        __user__["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(
+                                action_id, user.id
+                            )
+                        )
+                except Exception as e:
+                    print(e)
+
+                params = {**params, "__user__": __user__}
+
+            if inspect.iscoroutinefunction(action):
+                data = await action(**params)
+            else:
+                data = action(**params)
+
+        except Exception as e:
+            print(f"Error: {e}")
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+    return data
+
+
 ##################################
 #
 # Task Endpoints
@@ -1314,7 +1476,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.TITLE_GENERATION,
+        "task": str(TASKS.TITLE_GENERATION),
     }
 
     log.debug(payload)
@@ -1367,7 +1529,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": TASKS.QUERY_GENERATION,
+        "task": str(TASKS.QUERY_GENERATION),
     }
 
     print(payload)
@@ -1424,7 +1586,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.EMOJI_GENERATION,
+        "task": str(TASKS.EMOJI_GENERATION),
     }
 
     log.debug(payload)
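The new endpoint accepts the same message envelope the frontend's `chatAction()` helper sends (see `src/lib/apis/index.ts` below). A hypothetical client-side call in Python, with placeholder values for anything not fixed by this patch:

```python
# Hypothetical direct call to the new action endpoint; the JSON shape mirrors
# what chatAction() posts from the frontend. URL, token, ids are placeholders.
import requests

token = "YOUR_API_TOKEN"          # placeholder
session_id = "SOCKET_SESSION_ID"  # placeholder: socket.io sid that receives events

res = requests.post(
    "http://localhost:8080/api/chat/actions/my-action-id",
    headers={"Authorization": f"Bearer {token}"},
    json={
        "model": "llama3",
        "messages": [{"id": "msg-1", "role": "assistant", "content": "Hi!"}],
        "chat_id": "chat-1",
        "session_id": session_id,  # lets the action emit chat-events back
        "id": "msg-1",             # the message the action operates on
    },
)
res.raise_for_status()
print(res.json())  # whatever the Action's action() returned
```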
diff --git a/backend/migrations/env.py b/backend/migrations/env.py
index 7035cf9176..8046abff3a 100644
--- a/backend/migrations/env.py
+++ b/backend/migrations/env.py
@@ -27,7 +27,7 @@ config = context.config
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
 if config.config_file_name is not None:
-    fileConfig(config.config_file_name)
+    fileConfig(config.config_file_name, disable_existing_loggers=False)
 
 # add your model's MetaData object here
 # for 'autogenerate' support
diff --git a/backend/requirements.txt b/backend/requirements.txt
index b2bc0e0ac3..61185796db 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,6 +1,6 @@
 fastapi==0.111.0
 uvicorn[standard]==0.22.0
-pydantic==2.7.1
+pydantic==2.8.2
 python-multipart==0.0.9
 
 Flask==3.0.3
@@ -12,7 +12,7 @@ passlib[bcrypt]==1.7.4
 requests==2.32.3
 aiohttp==3.9.5
 
-sqlalchemy==2.0.30
+sqlalchemy==2.0.31
 alembic==1.13.2
 peewee==3.17.6
 peewee-migrate==1.12.2
@@ -38,12 +38,12 @@ langchain-community==0.2.6
 langchain-chroma==0.1.2
 
 fake-useragent==1.5.1
-chromadb==0.5.3
+chromadb==0.5.4
 sentence-transformers==3.0.1
 pypdf==4.2.0
 docx2txt==0.8
 python-pptx==0.6.23
-unstructured==0.14.9
+unstructured==0.14.10
 Markdown==3.6
 pypandoc==1.13
 pandas==2.2.2
@@ -71,7 +71,7 @@ pytube==15.0.0
 extract_msg
 pydub
-duckduckgo-search~=6.1.7
+duckduckgo-search~=6.1.12
 
 ## Tests
 docker~=7.1.0
diff --git a/package-lock.json b/package-lock.json
index 9eb09d4212..ebae1e7dc6 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
 	"name": "open-webui",
-	"version": "0.3.8",
+	"version": "0.3.9",
 	"lockfileVersion": 3,
 	"requires": true,
 	"packages": {
 		"": {
 			"name": "open-webui",
-			"version": "0.3.8",
+			"version": "0.3.9",
 			"dependencies": {
 				"@codemirror/lang-javascript": "^6.2.2",
 				"@codemirror/lang-python": "^6.1.6",
diff --git a/package.json b/package.json
index 080f8ed5ba..53a22b6997 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
 	"name": "open-webui",
-	"version": "0.3.8",
+	"version": "0.3.9",
 	"private": true,
 	"scripts": {
 		"dev": "npm run pyodide:fetch && vite dev --host",
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index 9558e98f50..c2e90855b6 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -104,6 +104,45 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
 	return res;
 };
 
+type ChatActionForm = {
+	model: string;
+	messages: string[];
+	chat_id: string;
+};
+
+export const chatAction = async (token: string, action_id: string, body: ChatActionForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/chat/actions/${action_id}`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify(body)
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getTaskConfig = async (token: string = '') => {
 	let error = null;
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index b32e544eef..5c0a47b357 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -425,7 +425,7 @@ export const resetUploadDir = async (token: string) => {
 export const resetVectorDB = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/reset`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
diff --git a/src/lib/components/admin/Settings/Images.svelte b/src/lib/components/admin/Settings/Images.svelte
index 80b5c93ba9..9838792f28 100644
--- a/src/lib/components/admin/Settings/Images.svelte
+++ b/src/lib/components/admin/Settings/Images.svelte
@@ -282,6 +282,7 @@
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index a6800f5127..b7b8dedea2 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -52,7 +52,7 @@
 	import { createOpenAITextStream } from '$lib/apis/streaming';
 	import { queryMemory } from '$lib/apis/memories';
 	import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
-	import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis';
+	import { chatCompleted, generateTitle, generateSearchQuery, chatAction } from '$lib/apis';
 
 	import Banner from '../common/Banner.svelte';
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -78,6 +78,8 @@
 	let showEventConfirmation = false;
 	let eventConfirmationTitle = '';
 	let eventConfirmationMessage = '';
+	let eventConfirmationInput = false;
+	let eventConfirmationInputPlaceholder = '';
 	let eventCallback = null;
 
 	let showModelSelector = true;
@@ -96,6 +98,8 @@
 	let title = '';
 	let prompt = '';
+
+	let chatFiles = [];
 	let files = [];
 	let messages = [];
 	let history = {
@@ -104,6 +108,7 @@
 	};
 
 	let params = {};
+	let valves = {};
 
 	$: if (history.currentId !== null) {
 		let _messages = [];
@@ -156,12 +161,27 @@
 			} else {
 				message.citations = [data];
 			}
+		} else if (type === 'message') {
+			message.content += data.content;
+		} else if (type === 'replace') {
+			message.content = data.content;
 		} else if (type === 'confirmation') {
 			eventCallback = cb;
+
+			eventConfirmationInput = false;
 			showEventConfirmation = true;
+
 			eventConfirmationTitle = data.title;
 			eventConfirmationMessage = data.message;
+		} else if (type === 'input') {
+			eventCallback = cb;
+
+			eventConfirmationInput = true;
+			showEventConfirmation = true;
+
+			eventConfirmationTitle = data.title;
+			eventConfirmationMessage = data.message;
+			eventConfirmationInputPlaceholder = data.placeholder;
 		} else {
 			console.log('Unknown message type', data);
 		}
@@ -315,6 +335,7 @@
 		}
 
 		params = chatContent?.params ?? {};
+		chatFiles = chatContent?.files ?? [];
 
 		autoScroll = true;
 		await tick();
@@ -347,7 +368,7 @@
 		}
 	};
 
-	const chatCompletedHandler = async (modelId, responseMessageId, messages) => {
+	const chatCompletedHandler = async (chatId, modelId, responseMessageId, messages) => {
 		await mermaid.run({
 			querySelector: '.mermaid'
 		});
@@ -361,7 +382,7 @@
 				info: m.info ? m.info : undefined,
 				timestamp: m.timestamp
 			})),
-			chat_id: $chatId,
+			chat_id: chatId,
 			session_id: $socket?.id,
 			id: responseMessageId
 		}).catch((error) => {
@@ -383,6 +404,65 @@
 				};
 			}
 		}
+
+		if ($chatId == chatId) {
+			if ($settings.saveChatHistory ?? true) {
+				chat = await updateChatById(localStorage.token, chatId, {
+					models: selectedModels,
+					messages: messages,
+					history: history,
+					params: params,
+					files: chatFiles
+				});
+				await chats.set(await getChatList(localStorage.token));
+			}
+		}
 	};
+
+	const chatActionHandler = async (chatId, actionId, modelId, responseMessageId) => {
+		const res = await chatAction(localStorage.token, actionId, {
+			model: modelId,
+			messages: messages.map((m) => ({
+				id: m.id,
+				role: m.role,
+				content: m.content,
+				info: m.info ? m.info : undefined,
+				timestamp: m.timestamp
+			})),
+			chat_id: chatId,
+			session_id: $socket?.id,
+			id: responseMessageId
+		}).catch((error) => {
+			toast.error(error);
+			messages.at(-1).error = { content: error };
+			return null;
+		});
+
+		if (res !== null) {
+			// Update chat history with the new messages
+			for (const message of res.messages) {
+				history.messages[message.id] = {
+					...history.messages[message.id],
+					...(history.messages[message.id].content !== message.content
+						? { originalContent: history.messages[message.id].content }
+						: {}),
+					...message
+				};
+			}
+		}
+
+		if ($chatId == chatId) {
+			if ($settings.saveChatHistory ?? true) {
+				chat = await updateChatById(localStorage.token, chatId, {
+					models: selectedModels,
+					messages: messages,
+					history: history,
+					params: params,
+					files: chatFiles
+				});
+				await chats.set(await getChatList(localStorage.token));
+			}
+		}
+	};
 
 	const getChatEventEmitter = async (modelId: string, chatId: string = '') => {
@@ -439,6 +519,13 @@
 		}
 
 		const _files = JSON.parse(JSON.stringify(files));
+		chatFiles.push(..._files.filter((item) => ['doc', 'file', 'collection'].includes(item.type)));
+		chatFiles = chatFiles.filter(
+			// Remove duplicates
+			(item, index, array) =>
+				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
+		);
+
 		files = [];
 
 		prompt = '';
@@ -679,25 +766,10 @@
 			}
 		});
 
-		let files = [];
+		let files = JSON.parse(JSON.stringify(chatFiles));
 		if (model?.info?.meta?.knowledge ?? false) {
-			files = model.info.meta.knowledge;
+			files.push(...model.info.meta.knowledge);
 		}
-		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
-
-		files = [
-			...files,
-			...(lastUserMessage?.files?.filter((item) =>
-				['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-			) ?? []),
-			...(responseMessage?.files?.filter((item) =>
-				['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-			) ?? [])
-		].filter(
-			// Remove duplicates
-			(item, index, array) =>
-				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
-		);
 
 		eventTarget.dispatchEvent(
 			new CustomEvent('chat:start', {
@@ -729,6 +801,7 @@
 					keep_alive: $settings.keepAlive ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
+					...(Object.keys(valves).length ? { valves } : {}),
 					session_id: $socket?.id,
 					chat_id: $chatId,
 					id: responseMessageId
@@ -752,7 +825,7 @@
 				controller.abort('User: Stop Response');
 			} else {
 				const messages = createMessagesList(responseMessageId);
-				await chatCompletedHandler(model.id, responseMessageId, messages);
+				await chatCompletedHandler(_chatId, model.id, responseMessageId, messages);
 			}
 
 			_response = responseMessage.content;
@@ -860,7 +933,8 @@
 					messages: messages,
 					history: history,
 					models: selectedModels,
-					params: params
+					params: params,
+					files: chatFiles
 				});
 				await chats.set(await getChatList(localStorage.token));
 			}
@@ -914,7 +988,7 @@
 			scrollToBottom();
 		}
 
-		if (messages.length == 2 && messages.at(1).content !== '') {
+		if (messages.length == 2 && messages.at(1).content !== '' && selectedModels[0] === model.id) {
 			window.history.replaceState(history.state, '', `/c/${_chatId}`);
 			const _title = await generateChatTitle(userPrompt);
 			await setChatTitle(_chatId, _title);
@@ -927,24 +1001,10 @@
 		let _response = null;
 		const responseMessage = history.messages[responseMessageId];
 
-		let files = [];
+		let files = JSON.parse(JSON.stringify(chatFiles));
 		if (model?.info?.meta?.knowledge ?? false) {
-			files = model.info.meta.knowledge;
+			files.push(...model.info.meta.knowledge);
 		}
-		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
-		files = [
-			...files,
-			...(lastUserMessage?.files?.filter((item) =>
-				['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-			) ?? []),
-			...(responseMessage?.files?.filter((item) =>
-				['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-			) ?? [])
-		].filter(
-			// Remove duplicates
-			(item, index, array) =>
-				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
-		);
 
 		scrollToBottom();
 
@@ -1033,6 +1093,7 @@
 					max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
+					...(Object.keys(valves).length ? { valves } : {}),
 					session_id: $socket?.id,
 					chat_id: $chatId,
 					id: responseMessageId
@@ -1064,7 +1125,7 @@
 				} else {
 					const messages = createMessagesList(responseMessageId);
-					await chatCompletedHandler(model.id, responseMessageId, messages);
+					await chatCompletedHandler(_chatId, model.id, responseMessageId, messages);
 				}
 
 				_response = responseMessage.content;
@@ -1137,7 +1198,8 @@
 					models: selectedModels,
 					messages: messages,
 					history: history,
-					params: params
+					params: params,
+					files: chatFiles
 				});
 				await chats.set(await getChatList(localStorage.token));
 			}
@@ -1175,7 +1237,7 @@
 			scrollToBottom();
 		}
 
-		if (messages.length == 2) {
+		if (messages.length == 2 && selectedModels[0] === model.id) {
 			window.history.replaceState(history.state, '', `/c/${_chatId}`);
 
 			const _title = await generateChatTitle(userPrompt);
@@ -1408,8 +1470,14 @@
 	bind:show={showEventConfirmation}
 	title={eventConfirmationTitle}
 	message={eventConfirmationMessage}
+	input={eventConfirmationInput}
+	inputPlaceholder={eventConfirmationInputPlaceholder}
 	on:confirm={(e) => {
-		eventCallback(true);
+		if (e.detail) {
+			eventCallback(e.detail);
+		} else {
+			eventCallback(true);
+		}
 	}}
 	on:cancel={() => {
 		eventCallback(false);
@@ -1511,6 +1579,7 @@
 					{sendPrompt}
 					{continueGeneration}
 					{regenerateResponse}
+					{chatActionHandler}
 				/>
@@ -1539,6 +1608,18 @@
-	<ChatControls bind:show={showControls} bind:params />
+	<ChatControls
+		models={selectedModels.reduce((a, e) => {
+			const model = $models.find((m) => m.id === e);
+			if (model) {
+				return [...a, model];
+			}
+			return a;
+		}, [])}
+		bind:show={showControls}
+		bind:chatFiles
+		bind:params
+		bind:valves
+	/>
 {/if}
diff --git a/src/lib/components/chat/ChatControls.svelte b/src/lib/components/chat/ChatControls.svelte
index 612af33785..f67e6d6efd 100644
--- a/src/lib/components/chat/ChatControls.svelte
+++ b/src/lib/components/chat/ChatControls.svelte
@@ -6,7 +6,12 @@
 
 	export let show = false;
 
+	export let models = [];
+
 	export let chatId = null;
+
+	export let chatFiles = [];
+	export let valves = {};
 	export let params = {};
 
 	let largeScreen = false;
@@ -43,6 +48,9 @@
 			on:close={() => {
 				show = false;
 			}}
+			{models}
+			bind:chatFiles
+			bind:valves
 			bind:params
 		/>
@@ -56,6 +64,9 @@
 			on:close={() => {
 				show = false;
 			}}
+			{models}
+			bind:chatFiles
+			bind:valves
 			bind:params
 		/>
diff --git a/src/lib/components/chat/Controls/Controls.svelte b/src/lib/components/chat/Controls/Controls.svelte
index 2cb3bceeb7..fe2036286b 100644
--- a/src/lib/components/chat/Controls/Controls.svelte
+++ b/src/lib/components/chat/Controls/Controls.svelte
@@ -5,7 +5,13 @@
 	import XMark from '$lib/components/icons/XMark.svelte';
 
 	import AdvancedParams from '../Settings/Advanced/AdvancedParams.svelte';
+	import Valves from '$lib/components/common/Valves.svelte';
+	import FileItem from '$lib/components/common/FileItem.svelte';
 
+	export let models = [];
+
+	export let chatFiles = [];
+	export let valves = {};
 	export let params = {};
 
@@ -23,15 +29,51 @@
+		{#if chatFiles.length > 0}
+			<div>
+				<div class="text-sm font-semibold mb-2">{$i18n.t('Files')}</div>
+
+				<div class="flex flex-col gap-1">
+					{#each chatFiles as file}
+						<FileItem
+							className="w-full"
+							url={file?.url ? file.url : null}
+							name={file.name}
+							type={file.type}
+							dismissible={true}
+							on:dismiss={() => {
+								// Remove the file from the chatFiles array
+								chatFiles = chatFiles.filter((f) => f.id !== file.id);
+							}}
+						/>
+					{/each}
+				</div>
+			</div>
+
+			<hr class="my-2" />
+		{/if}
+
+		{#if models.length === 1 && models[0]?.pipe?.valves_spec}
+			<div>
+				<div class="text-sm font-semibold mb-2">{$i18n.t('Valves')}</div>
+
+				<div>
+					<Valves valvesSpec={models[0]?.pipe?.valves_spec} bind:valves />
+				</div>
+			</div>
+
+			<hr class="my-2" />
+		{/if}
+
 		<div>
-			<div class="text-sm font-semibold mb-2">System Prompt</div>
+			<div class="text-sm font-semibold mb-2">{$i18n.t('System Prompt')}</div>