diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml index da40f56ffe..0e62be3d90 100644 --- a/.github/workflows/docker-build.yaml +++ b/.github/workflows/docker-build.yaml @@ -11,8 +11,6 @@ on: env: REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} - FULL_IMAGE_NAME: ghcr.io/${{ github.repository }} jobs: build-main-image: @@ -28,6 +26,15 @@ jobs: - linux/arm64 steps: + # GitHub Packages requires the entire repository name to be lowercase + # This repository's owner already has a lowercase username, but forks may not; without this step, users with uppercase names could not run these actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + - name: Prepare run: | platform=${{ matrix.platform }} @@ -116,6 +123,15 @@ jobs: - linux/arm64 steps: + # GitHub Packages requires the entire repository name to be lowercase + # This repository's owner already has a lowercase username, but forks may not; without this step, users with uppercase names could not run these actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + - name: Prepare run: | platform=${{ matrix.platform }} @@ -207,6 +223,15 @@ jobs: - linux/arm64 steps: + # GitHub Packages requires the entire repository name to be lowercase + # This repository's owner already has a lowercase username, but forks may not; without this step, users with uppercase names could not run these actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + - name: Prepare run: | platform=${{ matrix.platform }} @@ -289,6 +314,15 @@ jobs: runs-on: ubuntu-latest needs: [ build-main-image ] steps: + # GitHub Packages requires the entire repository name to be lowercase + # This repository's owner already has a lowercase username, but forks may not; without this step, users with uppercase names could not run these actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + - name: Download digests uses: actions/download-artifact@v4 with: @@ -335,6 +369,15 @@ jobs: runs-on: ubuntu-latest needs: [ build-cuda-image ] steps: + # GitHub Packages requires the entire repository name to be lowercase + # This repository's owner already has a lowercase username, but forks may not; without this step, users with uppercase names could not run these actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + - name: Download digests uses: actions/download-artifact@v4 with: @@ -382,6 +425,15 @@ jobs: runs-on: ubuntu-latest needs: [ build-ollama-image ] steps: + # GitHub Packages requires the entire repository name to be lowercase + # This repository's owner already has a lowercase username, but forks may not; without this step, users with uppercase names could not run these actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" 
>>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + - name: Download digests uses: actions/download-artifact@v4 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index bfff72eed2..6756d105b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,30 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.3.5] - 2024-06-16 + +### Added + +- **📞 Enhanced Voice Call**: Text-to-speech (TTS) callback now operates in real-time for each sentence, reducing latency by not waiting for full completion. +- **👆 Tap to Interrupt**: During a call, you can now stop the assistant from speaking by simply tapping, instead of using voice. This resolves the issue of the speaker's voice being mistakenly registered as input. +- **😊 Emoji Call**: Toggle this feature on from the Settings > Interface, allowing LLMs to express emotions using emojis during voice calls for a more dynamic interaction. +- **🖱️ Quick Archive/Delete**: Use the Shift key + mouseover on the chat list to swiftly archive or delete items. +- **📝 Markdown Support in Model Descriptions**: You can now format model descriptions with markdown, enabling bold text, links, etc. +- **🧠 Editable Memories**: Adds the capability to modify memories. +- **📋 Admin Panel Sorting**: Introduces the ability to sort users/chats within the admin panel. +- **🌑 Dark Mode for Quick Selectors**: Dark mode now available for chat quick selectors (prompts, models, documents). +- **🔧 Advanced Parameters**: Adds 'num_keep' and 'num_batch' to advanced parameters for customization. +- **📅 Dynamic System Prompts**: New variables '{{CURRENT_DATETIME}}', '{{CURRENT_TIME}}', '{{USER_LOCATION}}' added for system prompts. Ensure '{{USER_LOCATION}}' is toggled on from Settings > Interface. +- **🌐 Tavily Web Search**: Includes Tavily as a web search provider option. +- **🖊️ Federated Auth Usernames**: Ability to set user names for federated authentication. +- **🔗 Auto Clean URLs**: When adding connection URLs, trailing slashes are now automatically removed. +- **🌐 Enhanced Translations**: Improved Chinese and Swedish translations. + +### Fixed + +- **⏳ AIOHTTP_CLIENT_TIMEOUT**: Introduced a new environment variable 'AIOHTTP_CLIENT_TIMEOUT' for requests to Ollama lasting longer than 5 minutes. Default is 300 seconds; set to blank ('') for no timeout. +- **❌ Message Delete Freeze**: Resolved an issue where message deletion would sometimes cause the web UI to freeze. + ## [0.3.4] - 2024-06-12 ### Fixed diff --git a/README.md b/README.md index 444002fbc6..f3cfe0d274 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. -- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, and `Serply` and inject the results directly into your chat experience. 
+- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, and `Tavily` and inject the results directly into your chat experience. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. @@ -160,7 +160,7 @@ Check our Migration Guide available in our [Open WebUI Documentation](https://do If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this: ```bash -docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev +docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev ``` ## What's Next? 🌟 diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index 8e8f89da02..9bf242381c 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -18,6 +18,10 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main ``` +### Errors with Slow Responses from Ollama + +Open WebUI has a default timeout of 5 minutes for Ollama to finish generating a response. If needed, this can be adjusted via the `AIOHTTP_CLIENT_TIMEOUT` environment variable, which sets the timeout in seconds; setting it to an empty string ('') disables the timeout entirely. + ### General Connection Errors **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates. 
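For example, assuming the standard Docker setup shown in this guide, a deployment whose models routinely need more than five minutes could raise the limit to ten (the 600-second value below is purely illustrative):

```bash
# Illustrative sketch: raise the Ollama response timeout to 600 seconds (10 minutes)
docker run -d -p 3000:8080 \
  -e AIOHTTP_CLIENT_TIMEOUT=600 \
  -v open-webui:/app/backend/data \
  --name open-webui --restart always \
  ghcr.io/open-webui/open-webui:main
```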
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index 4419ccf199..af7e5592d5 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -37,6 +37,10 @@ from config import ( ENABLE_IMAGE_GENERATION, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL, + COMFYUI_CFG_SCALE, + COMFYUI_SAMPLER, + COMFYUI_SCHEDULER, + COMFYUI_SD3, IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_KEY, IMAGE_GENERATION_MODEL, @@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.IMAGE_SIZE = IMAGE_SIZE app.state.config.IMAGE_STEPS = IMAGE_STEPS +app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE +app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER +app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER +app.state.config.COMFYUI_SD3 = COMFYUI_SD3 @app.get("/config") @@ -457,6 +465,18 @@ def generate_image( if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt + if app.state.config.COMFYUI_CFG_SCALE: + data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE + + if app.state.config.COMFYUI_SAMPLER is not None: + data["sampler"] = app.state.config.COMFYUI_SAMPLER + + if app.state.config.COMFYUI_SCHEDULER is not None: + data["scheduler"] = app.state.config.COMFYUI_SCHEDULER + + if app.state.config.COMFYUI_SD3 is not None: + data["sd3"] = app.state.config.COMFYUI_SD3 + data = ImageGenerationPayload(**data) res = comfyui_generate_image( diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py index 05df1c1665..599b1f3379 100644 --- a/backend/apps/images/utils/comfyui.py +++ b/backend/apps/images/utils/comfyui.py @@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel): width: int height: int n: int = 1 + cfg_scale: Optional[float] = None + sampler: Optional[str] = None + scheduler: Optional[str] = None + sd3: Optional[bool] = None def comfyui_generate_image( @@ -199,6 +203,18 @@ comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) + if payload.cfg_scale: + comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale + + if payload.sampler: + comfyui_prompt["3"]["inputs"]["sampler_name"] = payload.sampler + + if payload.scheduler: + comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler + + if payload.sd3: + comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage" + comfyui_prompt["4"]["inputs"]["ckpt_name"] = model comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n comfyui_prompt["5"]["inputs"]["width"] = payload.width diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 1447554188..81a3b2a0e8 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -46,6 +46,7 @@ from config import ( SRC_LOG_LEVELS, OLLAMA_BASE_URLS, ENABLE_OLLAMA_API, + AIOHTTP_CLIENT_TIMEOUT, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, UPLOAD_DIR, @@ -154,7 +155,9 @@ async def cleanup_response( async def post_streaming_url(url: str, payload: str): r = None try: - session = aiohttp.ClientSession(trust_env=True) + session = aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) r = await session.post(url, data=payload) r.raise_for_status() @@ -751,6 +754,14 @@ async def generate_chat_completion( if model_info.params.get("num_ctx", None): payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + if model_info.params.get("num_batch", None): + payload["options"]["num_batch"] = model_info.params.get( + "num_batch", None + ) + + if model_info.params.get("num_keep", None): + 
payload["options"]["num_keep"] = model_info.params.get("num_keep", None) + if model_info.params.get("repeat_last_n", None): payload["options"]["repeat_last_n"] = model_info.params.get( "repeat_last_n", None @@ -839,8 +850,7 @@ async def generate_chat_completion( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - - print(payload) + log.debug(payload) return await post_streaming_url(f"{url}/api/chat", json.dumps(payload)) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 93f913dea2..c09c030d2d 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -430,13 +430,11 @@ async def generate_chat_completion( # Convert the modified body back to JSON payload = json.dumps(payload) - print(payload) + log.debug(payload) url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] - print(payload) - headers = {} headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 0e493eaaa3..4bd5da86cc 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serply import search_serply from apps.rag.search.duckduckgo import search_duckduckgo +from apps.rag.search.tavily import search_tavily from utils.misc import ( calculate_sha256, @@ -119,6 +120,7 @@ from config import ( SERPSTACK_HTTPS, SERPER_API_KEY, SERPLY_API_KEY, + TAVILY_API_KEY, RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_EMBEDDING_OPENAI_BATCH_SIZE, @@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS @@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)): "serpstack_https": app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel): serpstack_https: Optional[bool] = None serper_api_key: Optional[str] = None serply_api_key: Optional[str] = None + tavily_api_key: Optional[str] = None result_count: Optional[int] = None concurrent_requests: Optional[int] = None @@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key + app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.search.concurrent_requests @@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ "serpstack_https": 
app.state.config.SERPSTACK_HTTPS, "serper_api_key": app.state.config.SERPER_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY, + "tavily_api_key": app.state.config.TAVILY_API_KEY, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, }, @@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: - SERPSTACK_API_KEY - SERPER_API_KEY - SERPLY_API_KEY - + - TAVILY_API_KEY Args: query (str): The query to search for """ @@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]: raise Exception("No SERPLY_API_KEY found in environment variables") elif engine == "duckduckgo": return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) + elif engine == "tavily": + if app.state.config.TAVILY_API_KEY: + return search_tavily( + app.state.config.TAVILY_API_KEY, + query, + app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + ) + else: + raise Exception("No TAVILY_API_KEY found in environment variables") else: raise Exception("No search engine API key found in environment variables") diff --git a/backend/apps/rag/search/tavily.py b/backend/apps/rag/search/tavily.py new file mode 100644 index 0000000000..b15d6ef9d5 --- /dev/null +++ b/backend/apps/rag/search/tavily.py @@ -0,0 +1,39 @@ +import logging + +import requests + +from apps.rag.search.main import SearchResult +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: + """Search using Tavily's Search API and return the results as a list of SearchResult objects. + + Args: + api_key (str): A Tavily Search API key + query (str): The query to search for + + Returns: + List[SearchResult]: A list of search results + """ + url = "https://api.tavily.com/search" + data = {"query": query, "api_key": api_key} + + response = requests.post(url, json=data) + response.raise_for_status() + + json_response = response.json() + + raw_search_results = json_response.get("results", []) + + return [ + SearchResult( + link=result["url"], + title=result.get("title", ""), + snippet=result.get("content"), + ) + for result in raw_search_results[:count] + ] diff --git a/backend/apps/webui/internal/migrations/013_add_user_info.py b/backend/apps/webui/internal/migrations/013_add_user_info.py new file mode 100644 index 0000000000..0f68669cca --- /dev/null +++ b/backend/apps/webui/internal/migrations/013_add_user_info.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 013_add_user_info.py. 
+ +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Add the 'info' field to the 'user' table + migrator.add_fields("user", info=pw.TextField(null=True)) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + # Remove the 'info' field from the 'user' table + migrator.remove_fields("user", "info") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 62a0a7a7b5..190d2d1c3f 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -25,6 +25,7 @@ from config import ( USER_PERMISSIONS, WEBHOOK_URL, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, JWT_EXPIRES_IN, WEBUI_BANNERS, ENABLE_COMMUNITY_SHARING, @@ -40,6 +41,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER +app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 70e5577e94..ef63674abb 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -65,6 +65,20 @@ class MemoriesTable: else: return None + def update_memory_by_id( + self, + id: str, + content: str, + ) -> Optional[MemoryModel]: + try: + memory = Memory.get(Memory.id == id) + memory.content = content + memory.updated_at = int(time.time()) + memory.save() + return MemoryModel(**model_to_dict(memory)) + except: + return None + def get_memories(self) -> List[MemoryModel]: try: memories = Memory.select() diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 48811e8af0..485a9eea4e 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -26,6 +26,7 @@ class User(Model): api_key = CharField(null=True, unique=True) settings = JSONField(null=True) + info = JSONField(null=True) class Meta: database = DB @@ -50,6 +51,7 @@ class UserModel(BaseModel): api_key: Optional[str] = None settings: Optional[UserSettings] = None + info: 
Optional[dict] = None #################### diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index d45879a243..16e3957378 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -33,7 +33,11 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER +from config import ( + WEBUI_AUTH, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, +) router = APIRouter() @@ -110,11 +114,16 @@ async def signin(request: Request, form_data: SigninForm): raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() + trusted_name = trusted_email + if WEBUI_AUTH_TRUSTED_NAME_HEADER: + trusted_name = request.headers.get( + WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email + ) if not Users.get_user_by_email(trusted_email.lower()): await signup( request, SignupForm( - email=trusted_email, password=str(uuid.uuid4()), name=trusted_email + email=trusted_email, password=str(uuid.uuid4()), name=trusted_name ), ) user = Auths.authenticate_user_by_trusted_header(trusted_email) diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index 6448ebe1ee..3832fe9a16 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -44,6 +44,10 @@ class AddMemoryForm(BaseModel): content: str +class MemoryUpdateModel(BaseModel): + content: Optional[str] = None + + @router.post("/add", response_model=Optional[MemoryModel]) async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) @@ -62,6 +66,34 @@ async def add_memory( return memory +@router.post("/{memory_id}/update", response_model=Optional[MemoryModel]) +async def update_memory_by_id( + memory_id: str, + request: Request, + form_data: MemoryUpdateModel, + user=Depends(get_verified_user), +): + memory = Memories.update_memory_by_id(memory_id, form_data.content) + if memory is None: + raise HTTPException(status_code=404, detail="Memory not found") + + if form_data.content is not None: + memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content) + collection = CHROMA_CLIENT.get_or_create_collection( + name=f"user-memory-{user.id}" + ) + collection.upsert( + documents=[form_data.content], + ids=[memory.id], + embeddings=[memory_embedding], + metadatas=[ + {"created_at": memory.created_at, "updated_at": memory.updated_at} + ], + ) + + return memory + + ############################ # QueryMemory ############################ diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index eccafde103..270d72a238 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -115,6 +115,52 @@ async def update_user_settings_by_session_user( ) +############################ +# GetUserInfoBySessionUser +############################ + + +@router.get("/user/info", response_model=Optional[dict]) +async def get_user_info_by_session_user(user=Depends(get_verified_user)): + user = Users.get_user_by_id(user.id) + if user: + return user.info + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + +############################ +# UpdateUserInfoBySessionUser +############################ + + 
+@router.post("/user/info/update", response_model=Optional[dict]) +async def update_user_settings_by_session_user( + form_data: dict, user=Depends(get_verified_user) +): + user = Users.get_user_by_id(user.id) + if user: + if user.info is None: + user.info = {} + + user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) + if user: + return user.info + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + ############################ # GetUserById ############################ diff --git a/backend/config.py b/backend/config.py index 30a23f29ee..1a38a450db 100644 --- a/backend/config.py +++ b/backend/config.py @@ -294,6 +294,7 @@ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) +WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) @@ -425,6 +426,14 @@ OLLAMA_API_BASE_URL = os.environ.get( ) OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") +AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300") + +if AIOHTTP_CLIENT_TIMEOUT == "": + AIOHTTP_CLIENT_TIMEOUT = None +else: + AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) + + K8S_FLAG = os.environ.get("K8S_FLAG", "") USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false") @@ -942,6 +951,11 @@ SERPLY_API_KEY = PersistentConfig( os.getenv("SERPLY_API_KEY", ""), ) +TAVILY_API_KEY = PersistentConfig( + "TAVILY_API_KEY", + "rag.web.search.tavily_api_key", + os.getenv("TAVILY_API_KEY", ""), +) RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( "RAG_WEB_SEARCH_RESULT_COUNT", @@ -994,6 +1008,30 @@ COMFYUI_BASE_URL = PersistentConfig( os.getenv("COMFYUI_BASE_URL", ""), ) +COMFYUI_CFG_SCALE = PersistentConfig( + "COMFYUI_CFG_SCALE", + "image_generation.comfyui.cfg_scale", + os.getenv("COMFYUI_CFG_SCALE", ""), +) + +COMFYUI_SAMPLER = PersistentConfig( + "COMFYUI_SAMPLER", + "image_generation.comfyui.sampler", + os.getenv("COMFYUI_SAMPLER", ""), +) + +COMFYUI_SCHEDULER = PersistentConfig( + "COMFYUI_SCHEDULER", + "image_generation.comfyui.scheduler", + os.getenv("COMFYUI_SCHEDULER", ""), +) + +COMFYUI_SD3 = PersistentConfig( + "COMFYUI_SD3", + "image_generation.comfyui.sd3", + os.environ.get("COMFYUI_SD3", "").lower() == "true", +) + IMAGES_OPENAI_API_BASE_URL = PersistentConfig( "IMAGES_OPENAI_API_BASE_URL", "image_generation.openai.api_base_url", diff --git a/backend/main.py b/backend/main.py index de8827d12d..04f8861621 100644 --- a/backend/main.py +++ b/backend/main.py @@ -494,6 +494,9 @@ def filter_pipeline(payload, user): if "title" in payload: del payload["title"] + if "task" in payload: + del payload["task"] + return payload @@ -761,7 +764,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE content = title_generation_template( - template, form_data["prompt"], user.model_dump() + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, ) payload = { @@ -773,7 +781,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): "title": True, } - print(payload) + log.debug(payload) try: 
payload = filter_pipeline(payload, user) @@ -827,7 +835,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE content = search_query_generation_template( - template, form_data["prompt"], user.model_dump() + template, form_data["prompt"], {"name": user.name} ) payload = { @@ -835,6 +843,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, + "task": True, } print(payload) @@ -855,6 +864,75 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) return await generate_openai_chat_completion(payload, user=user) +@app.post("/api/task/emoji/completions") +async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): + print("generate_emoji") + + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + if app.state.MODELS[model_id]["owned_by"] == "ollama": + if app.state.config.TASK_MODEL: + task_model_id = app.state.config.TASK_MODEL + if task_model_id in app.state.MODELS: + model_id = task_model_id + else: + if app.state.config.TASK_MODEL_EXTERNAL: + task_model_id = app.state.config.TASK_MODEL_EXTERNAL + if task_model_id in app.state.MODELS: + model_id = task_model_id + + print(model_id) + model = app.state.MODELS[model_id] + + template = ''' +Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). 
+ +Message: """{{prompt}}""" +''' + + content = title_generation_template( + template, + form_data["prompt"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": 4, + "chat_id": form_data.get("chat_id", None), + "task": True, + } + + log.debug(payload) + + try: + payload = filter_pipeline(payload, user) + except Exception as e: + return JSONResponse( + status_code=e.args[0], + content={"detail": e.args[1]}, + ) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion( + OpenAIChatCompletionForm(**payload), user=user + ) + else: + return await generate_openai_chat_completion(payload, user=user) + + @app.post("/api/task/tools/completions") async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)): print("get_tools_function_calling") diff --git a/backend/utils/task.py b/backend/utils/task.py index 615febcdcd..ea277eb0be 100644 --- a/backend/utils/task.py +++ b/backend/utils/task.py @@ -6,24 +6,28 @@ from typing import Optional def prompt_template( - template: str, user_name: str = None, current_location: str = None + template: str, user_name: str = None, user_location: str = None ) -> str: # Get the current date current_date = datetime.now() # Format the date to YYYY-MM-DD formatted_date = current_date.strftime("%Y-%m-%d") + formatted_time = current_date.strftime("%I:%M:%S %p") - # Replace {{CURRENT_DATE}} in the template with the formatted date template = template.replace("{{CURRENT_DATE}}", formatted_date) + template = template.replace("{{CURRENT_TIME}}", formatted_time) + template = template.replace( + "{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}" + ) if user_name: # Replace {{USER_NAME}} in the template with the user's name template = template.replace("{{USER_NAME}}", user_name) - if current_location: - # Replace {{CURRENT_LOCATION}} in the template with the current location - template = template.replace("{{CURRENT_LOCATION}}", current_location) + if user_location: + # Replace {{USER_LOCATION}} in the template with the current location + template = template.replace("{{USER_LOCATION}}", user_location) return template @@ -61,7 +65,7 @@ def title_generation_template( template = prompt_template( template, **( - {"user_name": user.get("name"), "current_location": user.get("location")} + {"user_name": user.get("name"), "user_location": user.get("location")} if user else {} ), @@ -104,7 +108,7 @@ def search_query_generation_template( template = prompt_template( template, **( - {"user_name": user.get("name"), "current_location": user.get("location")} + {"user_name": user.get("name"), "user_location": user.get("location")} if user else {} ), diff --git a/package-lock.json b/package-lock.json index f5b9d6a788..513993c74d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.3.4", + "version": "0.3.5", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.3.4", + "version": "0.3.5", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", diff --git a/package.json b/package.json index bf353ef7f4..46aeb14f77 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.3.4", + "version": "0.3.5", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff 
--git a/src/app.css b/src/app.css index da1d961e53..baf620845b 100644 --- a/src/app.css +++ b/src/app.css @@ -28,6 +28,10 @@ math { @apply rounded-lg; } +.markdown a { + @apply underline; +} + ol > li { counter-increment: list-number; display: block; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index c40815611e..9558e98f50 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -205,6 +205,54 @@ export const generateTitle = async ( return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat'; }; +export const generateEmoji = async ( + token: string = '', + model: string, + prompt: string, + chat_id?: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/task/emoji/completions`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: prompt, + ...(chat_id && { chat_id: chat_id }) + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + const response = res?.choices[0]?.message?.content.replace(/["']/g, '') ?? null; + + if (response) { + if (/\p{Extended_Pictographic}/u.test(response)) { + return response.match(/\p{Extended_Pictographic}/gu)[0]; + } + } + + return null; +}; + export const generateSearchQuery = async ( token: string = '', model: string, diff --git a/src/lib/apis/memories/index.ts b/src/lib/apis/memories/index.ts index 44b24e2937..c3c122adf8 100644 --- a/src/lib/apis/memories/index.ts +++ b/src/lib/apis/memories/index.ts @@ -59,6 +59,37 @@ export const addNewMemory = async (token: string, content: string) => { return res; }; +export const updateMemoryById = async (token: string, id: string, content: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/memories/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const queryMemory = async (token: string, content: string) => { let error = null; diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index 4c97b00363..0b22b71715 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { getUserPosition } from '$lib/utils'; export const getUserPermissions = async (token: string) => { let error = null; @@ -198,6 +199,75 @@ export const getUserById = async (token: string, userId: string) => { return res; }; +export const getUserInfo = async (token: string) => { + let error = null; + const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateUserInfo 
= async (token: string, info: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/info/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...info + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getAndUpdateUserLocation = async (token: string) => { + const location = await getUserPosition().catch((err) => { + throw err; + }); + + if (location) { + await updateUserInfo(token, { location: location }); + return location; + } else { + throw new Error('Failed to get user location'); + } +}; + export const deleteUserById = async (token: string, userId: string) => { let error = null; diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte index 669fe8aae0..909a075812 100644 --- a/src/lib/components/admin/Settings/Connections.svelte +++ b/src/lib/components/admin/Settings/Connections.svelte @@ -44,6 +44,8 @@ let ENABLE_OLLAMA_API = null; const verifyOpenAIHandler = async (idx) => { + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); + OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS); OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS); @@ -63,6 +65,10 @@ }; const verifyOllamaHandler = async (idx) => { + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) => + url.replace(/\/$/, '') + ); + OLLAMA_BASE_URLS = await updateOllamaUrls(localStorage.token, OLLAMA_BASE_URLS); const res = await getOllamaVersion(localStorage.token, idx).catch((error) => { @@ -78,6 +84,8 @@ }; const updateOpenAIHandler = async () => { + OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.map((url) => url.replace(/\/$/, '')); + // Check if API KEYS length is same than API URLS length if (OPENAI_API_KEYS.length !== OPENAI_API_BASE_URLS.length) { // if there are more keys than urls, remove the extra keys @@ -100,7 +108,10 @@ }; const updateOllamaUrlsHandler = async () => { - OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== ''); + OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url) => url !== '').map((url) => + url.replace(/\/$/, '') + ); + console.log(OLLAMA_BASE_URLS); if (OLLAMA_BASE_URLS.length === 0) { diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index ab8996d925..af2dfcdc3e 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -13,6 +13,8 @@ getRAGConfig, updateRAGConfig } from '$lib/apis/rag'; + import ResetUploadDirConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; + import ResetVectorDBConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import { documents, models } from '$lib/stores'; import { onMount, getContext } from 'svelte'; @@ -213,6 +215,34 @@ }); + { + const res = resetUploadDir(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + toast.success($i18n.t('Success')); + } + }} +/> + + { + const res = resetVectorDB(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + toast.success($i18n.t('Success')); + } + }} +/> +
{ @@ -640,199 +670,56 @@
- {#if showResetUploadDirConfirm} -
-
- - - - - {$i18n.t('Are you sure?')} -
- -
- - -
+ - {/if} +
{$i18n.t('Reset Upload Directory')}
+ - {#if showResetConfirm} -
-
- - - - {$i18n.t('Are you sure?')} -
- -
- - -
+ - {/if} +
{$i18n.t('Reset Vector Storage')}
+
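For readers reconstructing the hunks above: both added dialogs reuse the generic `ConfirmDialog.svelte` component imported earlier under two aliases. A minimal sketch of the likely usage, with the handler body taken verbatim from the hunk (the `bind:show` prop and `on:confirm` event names are assumptions about the interface of `ConfirmDialog.svelte`):

```svelte
<!-- Sketch: the upload-directory reset dialog; showResetUploadDirConfirm is the
     flag visible in the removed {#if showResetUploadDirConfirm} block -->
<ResetUploadDirConfirmDialog
	bind:show={showResetUploadDirConfirm}
	on:confirm={() => {
		const res = resetUploadDir(localStorage.token).catch((error) => {
			toast.error(error);
			return null;
		});

		if (res) {
			toast.success($i18n.t('Success'));
		}
	}}
/>
```

The vector-storage dialog presumably follows the same pattern with `showResetConfirm` and `resetVectorDB`.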
diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index a943c6fb01..57d0be135a 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -1,5 +1,10 @@ + { + deleteModelHandler(); + }} +/> +
{#if ollamaEnabled} @@ -763,7 +773,7 @@
+ {:else if webConfig.search.engine === 'tavily'} +
+
+ {$i18n.t('Tavily API Key')} +
+ +
+
+ +
+
+
{/if}
{/if} diff --git a/src/lib/components/admin/UserChatsModal.svelte b/src/lib/components/admin/UserChatsModal.svelte index 67fa367cd1..535dee0740 100644 --- a/src/lib/components/admin/UserChatsModal.svelte +++ b/src/lib/components/admin/UserChatsModal.svelte @@ -31,6 +31,17 @@ } })(); } + + let sortKey = 'updated_at'; // default sort key + let sortOrder = 'desc'; // default sort order + function setSortKey(key) { + if (sortKey === key) { + sortOrder = sortOrder === 'asc' ? 'desc' : 'asc'; + } else { + sortKey = key; + sortOrder = 'asc'; + } + } @@ -69,18 +80,56 @@ class="text-xs text-gray-700 uppercase bg-transparent dark:text-gray-200 border-b-2 dark:border-gray-800" > - {$i18n.t('Name')} - {$i18n.t('Created at')} + setSortKey('title')} + > + {$i18n.t('Title')} + {#if sortKey === 'title'} + {sortOrder === 'asc' ? '▲' : '▼'} + {:else} + + {/if} + + setSortKey('created_at')} + > + {$i18n.t('Created at')} + {#if sortKey === 'created_at'} + {sortOrder === 'asc' ? '▲' : '▼'} + {:else} + + {/if} + + setSortKey('updated_at')} + > + {$i18n.t('Updated at')} + {#if sortKey === 'updated_at'} + {sortOrder === 'asc' ? '▲' : '▼'} + {:else} + + {/if} + - {#each chats as chat, idx} + {#each chats.sort((a, b) => { + if (a[sortKey] < b[sortKey]) return sortOrder === 'asc' ? -1 : 1; + if (a[sortKey] > b[sortKey]) return sortOrder === 'asc' ? 1 : -1; + return 0; + }) as chat, idx} - +
{chat.title} @@ -88,11 +137,16 @@ - +
{dayjs(chat.created_at * 1000).format($i18n.t('MMMM DD, YYYY HH:mm'))}
+ +
+ {dayjs(chat.updated_at * 1000).format($i18n.t('MMMM DD, YYYY HH:mm'))} +
+
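The sorting added to `UserChatsModal.svelte` above toggles direction when the active column header is clicked again and resets to ascending when a different column is chosen. A standalone TypeScript sketch of that logic, handy for testing it outside the component (illustrative only, not part of the patch; names mirror the component's `sortKey`/`sortOrder` state):

```typescript
type SortOrder = 'asc' | 'desc';

// Mirrors setSortKey(): clicking the active column flips the order,
// clicking a different column selects it and resets to ascending.
function nextSort(
	current: { key: string; order: SortOrder },
	clicked: string
): { key: string; order: SortOrder } {
	if (current.key === clicked) {
		return { key: clicked, order: current.order === 'asc' ? 'desc' : 'asc' };
	}
	return { key: clicked, order: 'asc' };
}

// Comparator equivalent to the inline chats.sort() in the template above.
function compareBy(key: string, order: SortOrder) {
	return (a: Record<string, any>, b: Record<string, any>): number => {
		if (a[key] < b[key]) return order === 'asc' ? -1 : 1;
		if (a[key] > b[key]) return order === 'asc' ? 1 : -1;
		return 0;
	};
}

// Example: sort a copy of the chat list by creation time, newest first.
// const sorted = [...chats].sort(compareBy('created_at', 'desc'));
```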
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 359056cfde..8819a0428a 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -30,6 +30,8 @@ import { convertMessagesToHistory, copyToClipboard, + extractSentencesForAudio, + getUserPosition, promptTemplate, splitStream } from '$lib/utils'; @@ -49,7 +51,7 @@ import { runWebSearch } from '$lib/apis/rag'; import { createOpenAITextStream } from '$lib/apis/streaming'; import { queryMemory } from '$lib/apis/memories'; - import { getUserSettings } from '$lib/apis/users'; + import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; import { chatCompleted, generateTitle, generateSearchQuery } from '$lib/apis'; import Banner from '../common/Banner.svelte'; @@ -64,6 +66,8 @@ export let chatIdProp = ''; let loaded = false; + const eventTarget = new EventTarget(); + let stopResponseFlag = false; let autoScroll = true; let processing = ''; @@ -108,7 +112,8 @@ $: if (chatIdProp) { (async () => { - if (await loadChat()) { + console.log(chatIdProp); + if (chatIdProp && (await loadChat())) { await tick(); loaded = true; @@ -123,7 +128,11 @@ onMount(async () => { if (!$chatId) { - await initNewChat(); + chatId.subscribe(async (value) => { + if (!value) { + await initNewChat(); + } + }); } else { if (!($settings.saveChatHistory ?? true)) { await goto('/'); @@ -300,7 +309,7 @@ // Chat functions ////////////////////////// - const submitPrompt = async (userPrompt, _user = null) => { + const submitPrompt = async (userPrompt, { _raw = false } = {}) => { let _responses = []; console.log('submitPrompt', $chatId); @@ -344,7 +353,6 @@ parentId: messages.length !== 0 ? messages.at(-1).id : null, childrenIds: [], role: 'user', - user: _user ?? undefined, content: userPrompt, files: _files.length > 0 ? 
_files : undefined, timestamp: Math.floor(Date.now() / 1000), // Unix epoch @@ -362,15 +370,13 @@ // Wait until history/message have been updated await tick(); - - // Send prompt - _responses = await sendPrompt(userPrompt, userMessageId); + _responses = await sendPrompt(userPrompt, userMessageId, { newChat: true }); } return _responses; }; - const sendPrompt = async (prompt, parentId, modelId = null, newChat = true) => { + const sendPrompt = async (prompt, parentId, { modelId = null, newChat = false } = {}) => { let _responses = []; // If modelId is provided, use it, else use selected model @@ -490,7 +496,6 @@ responseMessage.userContext = userContext; const chatEventEmitter = await getChatEventEmitter(model.id, _chatId); - if (webSearchEnabled) { await getWebSearchResults(model.id, parentId, responseMessageId); } @@ -503,8 +508,6 @@ } _responses.push(_response); - console.log('chatEventEmitter', chatEventEmitter); - if (chatEventEmitter) clearInterval(chatEventEmitter); } else { toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); @@ -513,88 +516,9 @@ ); await chats.set(await getChatList(localStorage.token)); - return _responses; }; - const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { - const responseMessage = history.messages[responseId]; - - responseMessage.statusHistory = [ - { - done: false, - action: 'web_search', - description: $i18n.t('Generating search query') - } - ]; - messages = messages; - - const prompt = history.messages[parentId].content; - let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( - (error) => { - console.log(error); - return prompt; - } - ); - - if (!searchQuery) { - toast.warning($i18n.t('No search query generated')); - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search query generated' - }); - - messages = messages; - } - - responseMessage.statusHistory.push({ - done: false, - action: 'web_search', - description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) - }); - messages = messages; - - const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => { - console.log(error); - toast.error(error); - - return null; - }); - - if (results) { - responseMessage.statusHistory.push({ - done: true, - action: 'web_search', - description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), - query: searchQuery, - urls: results.filenames - }); - - if (responseMessage?.files ?? undefined === undefined) { - responseMessage.files = []; - } - - responseMessage.files.push({ - collection_name: results.collection_name, - name: searchQuery, - type: 'web_search_results', - urls: results.filenames - }); - - messages = messages; - } else { - responseMessage.statusHistory.push({ - done: true, - error: true, - action: 'web_search', - description: 'No search results found' - }); - messages = messages; - } - }; - const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => { let _response = null; @@ -610,7 +534,13 @@ $settings.system || (responseMessage?.userContext ?? null) ? { role: 'system', - content: `${promptTemplate($settings?.system ?? '', $user.name)}${ + content: `${promptTemplate( + $settings?.system ?? '', + $user.name, + $settings?.userLocation + ? await getAndUpdateUserLocation(localStorage.token) + : undefined + )}${ responseMessage?.userContext ?? null ? `\n\nUser Context:\n${(responseMessage?.userContext ?? 
[]).join('\n')}` : '' @@ -676,6 +606,16 @@ array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index ); + eventTarget.dispatchEvent( + new CustomEvent('chat:start', { + detail: { + id: responseMessageId + } + }) + ); + + await tick(); + const [res, controller] = await generateChatCompletion(localStorage.token, { model: model.id, messages: messagesBody, @@ -745,6 +685,23 @@ continue; } else { responseMessage.content += data.message.content; + + const sentences = extractSentencesForAudio(responseMessage.content); + sentences.pop(); + + // dispatch only last sentence and make sure it hasn't been dispatched before + if ( + sentences.length > 0 && + sentences[sentences.length - 1] !== responseMessage.lastSentence + ) { + responseMessage.lastSentence = sentences[sentences.length - 1]; + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: sentences[sentences.length - 1] } + }) + ); + } + messages = messages; } } else { @@ -771,21 +728,13 @@ messages = messages; if ($settings.notificationEnabled && !document.hasFocus()) { - const notification = new Notification( - selectedModelfile - ? `${ - selectedModelfile.title.charAt(0).toUpperCase() + - selectedModelfile.title.slice(1) - }` - : `${model.id}`, - { - body: responseMessage.content, - icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png` - } - ); + const notification = new Notification(`${model.id}`, { + body: responseMessage.content, + icon: `${WEBUI_BASE_URL}/static/favicon.png` + }); } - if ($settings.responseAutoCopy) { + if ($settings?.responseAutoCopy ?? false) { copyToClipboard(responseMessage.content); } @@ -847,6 +796,23 @@ stopResponseFlag = false; await tick(); + let lastSentence = extractSentencesForAudio(responseMessage.content)?.at(-1) ?? ''; + if (lastSentence) { + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: lastSentence } + }) + ); + } + eventTarget.dispatchEvent( + new CustomEvent('chat:finish', { + detail: { + id: responseMessageId, + content: responseMessage.content + } + }) + ); + if (autoScroll) { scrollToBottom(); } @@ -887,6 +853,15 @@ scrollToBottom(); + eventTarget.dispatchEvent( + new CustomEvent('chat:start', { + detail: { + id: responseMessageId + } + }) + ); + await tick(); + try { const [res, controller] = await generateOpenAIChatCompletion( localStorage.token, @@ -903,7 +878,13 @@ $settings.system || (responseMessage?.userContext ?? null) ? { role: 'system', - content: `${promptTemplate($settings?.system ?? '', $user.name)}${ + content: `${promptTemplate( + $settings?.system ?? '', + $user.name, + $settings?.userLocation + ? await getAndUpdateUserLocation(localStorage.token) + : undefined + )}${ responseMessage?.userContext ?? null ? `\n\nUser Context:\n${(responseMessage?.userContext ?? 
[]).join('\n')}` : '' @@ -1007,6 +988,23 @@ continue; } else { responseMessage.content += value; + + const sentences = extractSentencesForAudio(responseMessage.content); + sentences.pop(); + + // dispatch only last sentence and make sure it hasn't been dispatched before + if ( + sentences.length > 0 && + sentences[sentences.length - 1] !== responseMessage.lastSentence + ) { + responseMessage.lastSentence = sentences[sentences.length - 1]; + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: sentences[sentences.length - 1] } + }) + ); + } + messages = messages; } @@ -1057,6 +1055,24 @@ stopResponseFlag = false; await tick(); + let lastSentence = extractSentencesForAudio(responseMessage.content)?.at(-1) ?? ''; + if (lastSentence) { + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { id: responseMessageId, content: lastSentence } + }) + ); + } + + eventTarget.dispatchEvent( + new CustomEvent('chat:finish', { + detail: { + id: responseMessageId, + content: responseMessage.content + } + }) + ); + if (autoScroll) { scrollToBottom(); } @@ -1123,9 +1139,12 @@ let userPrompt = userMessage.content; if ((userMessage?.models ?? [...selectedModels]).length == 1) { - await sendPrompt(userPrompt, userMessage.id, undefined, false); + // If the user message has only one model selected, sendPrompt automatically selects it for regeneration + await sendPrompt(userPrompt, userMessage.id); } else { - await sendPrompt(userPrompt, userMessage.id, message.model, false); + // If there are multiple models selected, use the model of the response message for regeneration + // e.g. a many-model chat + await sendPrompt(userPrompt, userMessage.id, { modelId: message.model }); } } }; @@ -1191,6 +1210,84 @@ } }; + const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { + const responseMessage = history.messages[responseId]; + + responseMessage.statusHistory = [ + { + done: false, + action: 'web_search', + description: $i18n.t('Generating search query') + } + ]; + messages = messages; + + const prompt = history.messages[parentId].content; + let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( + (error) => { + console.log(error); + return prompt; + } + ); + + if (!searchQuery) { + toast.warning($i18n.t('No search query generated')); + responseMessage.statusHistory.push({ + done: true, + error: true, + action: 'web_search', + description: 'No search query generated' + }); + + messages = messages; + } + + responseMessage.statusHistory.push({ + done: false, + action: 'web_search', + description: $i18n.t(`Searching "{{searchQuery}}"`, { searchQuery }) + }); + messages = messages; + + const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => { + console.log(error); + toast.error(error); + + return null; + }); + + if (results) { + responseMessage.statusHistory.push({ + done: true, + action: 'web_search', + description: $i18n.t('Searched {{count}} sites', { count: results.filenames.length }), + query: searchQuery, + urls: results.filenames + }); + + if ((responseMessage?.files ?? 
undefined) === undefined) { + responseMessage.files = []; + } + + responseMessage.files.push({ + collection_name: results.collection_name, + name: searchQuery, + type: 'web_search_results', + urls: results.filenames + }); + + messages = messages; + } else { + responseMessage.statusHistory.push({ + done: true, + error: true, + action: 'web_search', + description: 'No search results found' + }); + messages = messages; + } + }; + const getTags = async () => { return await getTagsById(localStorage.token, $chatId).catch(async (error) => { return []; }); }; @@ -1206,7 +1303,18 @@ - +