diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index ae2e2529c3..7f603cb10c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -73,4 +73,4 @@ ### Contributor License Agreement -By submitting this pull request, I confirm that I have read and fully agree to the [CONTRIBUTOR_LICENSE_AGREEMENT](CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms. \ No newline at end of file +By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms. diff --git a/LICENSE b/LICENSE index 89109d7516..3991050972 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2023-2025 Timothy Jaeryang Baek +Copyright (c) 2023-2025 Timothy Jaeryang Baek (Open WebUI) All rights reserved. Redistribution and use in source and binary forms, with or without @@ -15,6 +15,12 @@ modification, are permitted provided that the following conditions are met: contributors may be used to endorse or promote products derived from this software without specific prior written permission. +4. Notwithstanding any other provision of this License, and as a material condition of the rights granted herein, licensees are strictly prohibited from altering, removing, obscuring, or replacing any "Open WebUI" branding, including but not limited to the name, logo, or any visual, textual, or symbolic identifiers that distinguish the software and its interfaces, in any deployment or distribution, regardless of the number of users, except as explicitly set forth in Clauses 5 and 6 below. + +5. The branding restriction enumerated in Clause 4 shall not apply in the following limited circumstances: (i) deployments or distributions where the total number of end users (defined as individual natural persons with direct access to the application) does not exceed fifty (50) within any rolling thirty (30) day period; (ii) cases in which the licensee is an official contributor to the codebase—with a substantive code change successfully merged into the main branch of the official codebase maintained by the copyright holder—who has obtained specific prior written permission for branding adjustment from the copyright holder; or (iii) where the licensee has obtained a duly executed enterprise license expressly permitting such modification. For all other cases, any removal or alteration of the "Open WebUI" branding shall constitute a material breach of license. + +6. All code, modifications, or derivative works incorporated into this project prior to the incorporation of this branding clause remain licensed under the BSD 3-Clause License, and prior contributors retain all BSD-3 rights therein; if any such contributor requests the removal of their BSD-3-licensed code, the copyright holder will do so, and any replacement code will be licensed under the project's primary license then in effect. By contributing after this clause's adoption, you agree to the project's Contributor License Agreement (CLA) and to these updated terms for all new contributions. 
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/README.md b/README.md index 54ad41503d..15f1731592 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ ![GitHub language count](https://img.shields.io/github/languages/count/open-webui/open-webui) ![GitHub top language](https://img.shields.io/github/languages/top/open-webui/open-webui) ![GitHub last commit](https://img.shields.io/github/last-commit/open-webui/open-webui?color=red) -![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Follama-webui%2Follama-wbui&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false) [![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s) [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck) @@ -206,7 +205,7 @@ Discover upcoming features on our roadmap in the [Open WebUI Documentation](http ## License 📜 -This project is licensed under the [BSD-3-Clause License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄 +This project is licensed under the [Open WebUI License](LICENSE), a revised BSD-3-Clause license. You receive all the same rights as the classic BSD-3 license: you can use, modify, and distribute the software, including in proprietary and commercial products, with minimal restrictions. The only additional requirement is to preserve the "Open WebUI" branding, as detailed in the LICENSE file. For full terms, see the [LICENSE](LICENSE) document. 
📄 ## Support 💬 diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py index ff386957c4..967a49de8f 100644 --- a/backend/open_webui/__init__.py +++ b/backend/open_webui/__init__.py @@ -76,7 +76,7 @@ def serve( from open_webui.env import UVICORN_WORKERS # Import the workers setting uvicorn.run( - open_webui.main.app, + "open_webui.main:app", host=host, port=port, forwarded_allow_ips="*", diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 3b40977f27..387232797e 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -509,6 +509,12 @@ ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig( os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true", ) +ENABLE_OAUTH_GROUP_CREATION = PersistentConfig( + "ENABLE_OAUTH_GROUP_CREATION", + "oauth.enable_group_creation", + os.environ.get("ENABLE_OAUTH_GROUP_CREATION", "False").lower() == "true", +) + OAUTH_ROLES_CLAIM = PersistentConfig( "OAUTH_ROLES_CLAIM", "oauth.roles_claim", @@ -952,10 +958,13 @@ DEFAULT_MODELS = PersistentConfig( "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None) ) -DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( - "DEFAULT_PROMPT_SUGGESTIONS", - "ui.prompt_suggestions", - [ +try: + default_prompt_suggestions = json.loads(os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]")) +except Exception as e: + log.exception(f"Error loading DEFAULT_PROMPT_SUGGESTIONS: {e}") + default_prompt_suggestions = [] +if default_prompt_suggestions == []: + default_prompt_suggestions = [ { "title": ["Help me study", "vocabulary for a college entrance exam"], "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", @@ -983,7 +992,11 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( "title": ["Overcome procrastination", "give me tips"], "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?", }, - ], + ] +DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( + "DEFAULT_PROMPT_SUGGESTIONS", + "ui.prompt_suggestions", + default_prompt_suggestions, ) MODEL_ORDER_LIST = PersistentConfig( @@ -1062,6 +1075,14 @@ USER_PERMISSIONS_CHAT_EDIT = ( os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true" ) +USER_PERMISSIONS_CHAT_SHARE = ( + os.environ.get("USER_PERMISSIONS_CHAT_SHARE", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_EXPORT = ( + os.environ.get("USER_PERMISSIONS_CHAT_EXPORT", "True").lower() == "true" +) + USER_PERMISSIONS_CHAT_STT = ( os.environ.get("USER_PERMISSIONS_CHAT_STT", "True").lower() == "true" ) @@ -1126,6 +1147,8 @@ DEFAULT_USER_PERMISSIONS = { "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, "delete": USER_PERMISSIONS_CHAT_DELETE, "edit": USER_PERMISSIONS_CHAT_EDIT, + "share": USER_PERMISSIONS_CHAT_SHARE, + "export": USER_PERMISSIONS_CHAT_EXPORT, "stt": USER_PERMISSIONS_CHAT_STT, "tts": USER_PERMISSIONS_CHAT_TTS, "call": USER_PERMISSIONS_CHAT_CALL, @@ -1203,6 +1226,9 @@ ENABLE_USER_WEBHOOKS = PersistentConfig( os.environ.get("ENABLE_USER_WEBHOOKS", "True").lower() == "true", ) +# FastAPI / AnyIO settings +THREAD_POOL_SIZE = int(os.getenv("THREAD_POOL_SIZE", "0")) + def validate_cors_origins(origins): for origin in origins: @@ -1693,6 +1719,9 @@ MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None) # Qdrant QDRANT_URI = os.environ.get("QDRANT_URI", None) QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) +QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", 
"false").lower() == "true" +QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true" +QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334")) # OpenSearch OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") @@ -1724,6 +1753,14 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int( os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536") ) +# Pinecone +PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) +PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) +PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "open-webui-index") +PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536)) # or 3072, 1024, 768 +PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine") +PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws") # or "gcp" or "azure" + #################################### # Information Retrieval (RAG) #################################### @@ -1760,6 +1797,13 @@ ONEDRIVE_CLIENT_ID = PersistentConfig( os.environ.get("ONEDRIVE_CLIENT_ID", ""), ) +ONEDRIVE_SHAREPOINT_URL = PersistentConfig( + "ONEDRIVE_SHAREPOINT_URL", + "onedrive.sharepoint_url", + os.environ.get("ONEDRIVE_SHAREPOINT_URL", ""), +) + + # RAG Content Extraction CONTENT_EXTRACTION_ENGINE = PersistentConfig( "CONTENT_EXTRACTION_ENGINE", @@ -2251,6 +2295,29 @@ FIRECRAWL_API_BASE_URL = PersistentConfig( os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"), ) +EXTERNAL_WEB_SEARCH_URL = PersistentConfig( + "EXTERNAL_WEB_SEARCH_URL", + "rag.web.search.external_web_search_url", + os.environ.get("EXTERNAL_WEB_SEARCH_URL", ""), +) + +EXTERNAL_WEB_SEARCH_API_KEY = PersistentConfig( + "EXTERNAL_WEB_SEARCH_API_KEY", + "rag.web.search.external_web_search_api_key", + os.environ.get("EXTERNAL_WEB_SEARCH_API_KEY", ""), +) + +EXTERNAL_WEB_LOADER_URL = PersistentConfig( + "EXTERNAL_WEB_LOADER_URL", + "rag.web.loader.external_web_loader_url", + os.environ.get("EXTERNAL_WEB_LOADER_URL", ""), +) + +EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig( + "EXTERNAL_WEB_LOADER_API_KEY", + "rag.web.loader.external_web_loader_api_key", + os.environ.get("EXTERNAL_WEB_LOADER_API_KEY", ""), +) #################################### # Images diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index c9d71a4a03..59557349e3 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -354,6 +354,10 @@ BYPASS_MODEL_ACCESS_CONTROL = ( os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" ) +WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get( + "WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None +) + #################################### # WEBUI_SECRET_KEY #################################### @@ -409,6 +413,11 @@ else: except Exception: AIOHTTP_CLIENT_TIMEOUT = 300 + +AIOHTTP_CLIENT_SESSION_SSL = ( + os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true" +) + AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get( "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST", os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"), @@ -437,6 +446,56 @@ else: except Exception: AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10 + +AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = ( + os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true" +) + + +#################################### +# SENTENCE TRANSFORMERS +#################################### + + +SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "") +if SENTENCE_TRANSFORMERS_BACKEND == "": + SENTENCE_TRANSFORMERS_BACKEND = "torch" + + 
+SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get( + "SENTENCE_TRANSFORMERS_MODEL_KWARGS", "" +) +if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "": + SENTENCE_TRANSFORMERS_MODEL_KWARGS = None +else: + try: + SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads( + SENTENCE_TRANSFORMERS_MODEL_KWARGS + ) + except Exception: + SENTENCE_TRANSFORMERS_MODEL_KWARGS = None + + +SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get( + "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", "" +) +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "": + SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch" + + +SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get( + "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", "" +) +if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "": + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None +else: + try: + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads( + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS + ) + except Exception: + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None + #################################### # OFFLINE_MODE #################################### @@ -446,6 +505,7 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" if OFFLINE_MODE: os.environ["HF_HUB_OFFLINE"] = "1" + #################################### # AUDIT LOGGING #################################### @@ -467,6 +527,7 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders" AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] + #################################### # OPENTELEMETRY #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 56ea17fa18..22a02e1491 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -17,6 +17,7 @@ from sqlalchemy import text from typing import Optional from aiocache import cached import aiohttp +import anyio.to_thread import requests @@ -100,11 +101,14 @@ from open_webui.config import ( # OpenAI ENABLE_OPENAI_API, ONEDRIVE_CLIENT_ID, + ONEDRIVE_SHAREPOINT_URL, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, # Direct Connections ENABLE_DIRECT_CONNECTIONS, + # Thread pool size for FastAPI/AnyIO + THREAD_POOL_SIZE, # Tool Server Configs TOOL_SERVER_CONNECTIONS, # Code Execution @@ -240,12 +244,17 @@ from open_webui.config import ( GOOGLE_DRIVE_CLIENT_ID, GOOGLE_DRIVE_API_KEY, ONEDRIVE_CLIENT_ID, + ONEDRIVE_SHAREPOINT_URL, ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_WEB_LOADER_SSL_VERIFICATION, ENABLE_GOOGLE_DRIVE_INTEGRATION, ENABLE_ONEDRIVE_INTEGRATION, UPLOAD_DIR, + EXTERNAL_WEB_SEARCH_URL, + EXTERNAL_WEB_SEARCH_API_KEY, + EXTERNAL_WEB_LOADER_URL, + EXTERNAL_WEB_LOADER_API_KEY, # WebUI WEBUI_AUTH, WEBUI_NAME, @@ -341,6 +350,7 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SECURE, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_AUTH_SIGNOUT_REDIRECT_URL, ENABLE_WEBSOCKET_SUPPORT, BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, @@ -432,6 +442,11 @@ async def lifespan(app: FastAPI): if LICENSE_KEY: get_license_data(app, LICENSE_KEY) + pool_size = THREAD_POOL_SIZE + if pool_size and pool_size > 0: + limiter = anyio.to_thread.current_default_thread_limiter() + limiter.total_tokens = pool_size + asyncio.create_task(periodic_usage_pool_cleanup()) yield @@ -576,6 +591,7 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS app.state.AUTH_TRUSTED_EMAIL_HEADER = 
WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER +app.state.WEBUI_AUTH_SIGNOUT_REDIRECT_URL = WEBUI_AUTH_SIGNOUT_REDIRECT_URL app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL app.state.USER_COUNT = None @@ -668,6 +684,10 @@ app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY app.state.config.SOUGOU_API_SID = SOUGOU_API_SID app.state.config.SOUGOU_API_SK = SOUGOU_API_SK +app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL +app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY +app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL +app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL @@ -1327,7 +1347,10 @@ async def get_app_config(request: Request): "client_id": GOOGLE_DRIVE_CLIENT_ID.value, "api_key": GOOGLE_DRIVE_API_KEY.value, }, - "onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value}, + "onedrive": { + "client_id": ONEDRIVE_CLIENT_ID.value, + "sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value, + }, "license_metadata": app.state.LICENSE_METADATA, **( { @@ -1439,7 +1462,7 @@ async def get_manifest_json(): "start_url": "/", "display": "standalone", "background_color": "#343541", - "orientation": "natural", + "orientation": "any", "icons": [ { "src": "/static/logo.png", diff --git a/backend/open_webui/retrieval/loaders/external.py b/backend/open_webui/retrieval/loaders/external.py new file mode 100644 index 0000000000..642cfd3a5e --- /dev/null +++ b/backend/open_webui/retrieval/loaders/external.py @@ -0,0 +1,53 @@ +import requests +import logging +from typing import Iterator, List, Union + +from langchain_core.document_loaders import BaseLoader +from langchain_core.documents import Document +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +class ExternalLoader(BaseLoader): + def __init__( + self, + web_paths: Union[str, List[str]], + external_url: str, + external_api_key: str, + continue_on_failure: bool = True, + **kwargs, + ) -> None: + self.external_url = external_url + self.external_api_key = external_api_key + self.urls = web_paths if isinstance(web_paths, list) else [web_paths] + self.continue_on_failure = continue_on_failure + + def lazy_load(self) -> Iterator[Document]: + batch_size = 20 + for i in range(0, len(self.urls), batch_size): + urls = self.urls[i : i + batch_size] + try: + response = requests.post( + self.external_url, + headers={ + "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", + "Authorization": f"Bearer {self.external_api_key}", + }, + json={ + "urls": urls, + }, + ) + response.raise_for_status() + results = response.json() + for result in results: + yield Document( + page_content=result.get("page_content", ""), + metadata=result.get("metadata", {}), + ) + except Exception as e: + if self.continue_on_failure: + log.error(f"Error extracting content from batch {urls}: {e}") + else: + raise e diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 2b23cad17f..410945c810 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -260,23 +260,47 @@ def query_collection( k: int, ) -> dict: results = [] - for query in queries: - log.debug(f"query_collection:query {query}") - query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX) - for 
collection_name in collection_names: + error = False + + def process_query_collection(collection_name, query_embedding): + try: if collection_name: - try: - result = query_doc( - collection_name=collection_name, - k=k, - query_embedding=query_embedding, - ) - if result is not None: - results.append(result.model_dump()) - except Exception as e: - log.exception(f"Error when querying the collection: {e}") - else: - pass + result = query_doc( + collection_name=collection_name, + k=k, + query_embedding=query_embedding, + ) + if result is not None: + return result.model_dump(), None + return None, None + except Exception as e: + log.exception(f"Error when querying the collection: {e}") + return None, e + + # Generate all query embeddings (in one call) + query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX) + log.debug( + f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections" + ) + + with ThreadPoolExecutor() as executor: + future_results = [] + for query_embedding in query_embeddings: + for collection_name in collection_names: + result = executor.submit( + process_query_collection, collection_name, query_embedding + ) + future_results.append(result) + task_results = [future.result() for future in future_results] + + for result, err in task_results: + if err is not None: + error = True + elif result is not None: + results.append(result) + + if error and not results: + log.warning("All collection queries failed. No results returned.") return merge_and_sort_query_results(results, k=k) diff --git a/backend/open_webui/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py index ac8884c043..198e6f1761 100644 --- a/backend/open_webui/retrieval/vector/connector.py +++ b/backend/open_webui/retrieval/vector/connector.py @@ -20,6 +20,10 @@ elif VECTOR_DB == "elasticsearch": from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient VECTOR_DB_CLIENT = ElasticsearchClient() +elif VECTOR_DB == "pinecone": + from open_webui.retrieval.vector.dbs.pinecone import PineconeClient + + VECTOR_DB_CLIENT = PineconeClient() else: from open_webui.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index a6b97df3e9..f9adc9c95f 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, @@ -23,7 +28,7 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ChromaClient: +class ChromaClient(VectorDBBase): def __init__(self): settings_dict = { "allow_reset": True, diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index c896284946..18a915e381 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError from typing import Optional import ssl from elasticsearch.helpers import bulk, scan -from open_webui.retrieval.vector.main import VectorItem, 
SearchResult, GetResult +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) from open_webui.config import ( ELASTICSEARCH_URL, ELASTICSEARCH_CA_CERTS, @@ -15,7 +20,7 @@ from open_webui.config import ( ) -class ElasticsearchClient: +class ElasticsearchClient(VectorDBBase): """ Important: in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 26b4dd5ed2..f116c57f79 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -4,7 +4,12 @@ import json import logging from typing import Optional -from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) from open_webui.config import ( MILVUS_URI, MILVUS_DB, @@ -16,7 +21,7 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class MilvusClient: +class MilvusClient(VectorDBBase): def __init__(self): self.collection_prefix = "open_webui" if MILVUS_TOKEN is None: diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 432bcef412..60ef2d906c 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -2,7 +2,12 @@ from opensearchpy import OpenSearch from opensearchpy.helpers import bulk from typing import Optional -from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) from open_webui.config import ( OPENSEARCH_URI, OPENSEARCH_SSL, @@ -12,7 +17,7 @@ from open_webui.config import ( ) -class OpenSearchClient: +class OpenSearchClient(VectorDBBase): def __init__(self): self.index_prefix = "open_webui" self.client = OpenSearch( diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index c38dbb0367..cd875b4064 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.exc import NoSuchTableError -from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH from open_webui.env import SRC_LOG_LEVELS @@ -44,7 +49,7 @@ class DocumentChunk(Base): vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) -class PgvectorClient: +class PgvectorClient(VectorDBBase): def __init__(self) -> None: # if no pgvector uri, use the existing database connection diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py new file mode 100644 index 0000000000..bc9bd8bc36 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -0,0 +1,412 @@ +from typing import Optional, List, Dict, Any, Union +import logging +from pinecone import Pinecone, ServerlessSpec + +from open_webui.retrieval.vector.main import ( + VectorDBBase, + 
VectorItem, + SearchResult, + GetResult, +) +from open_webui.config import ( + PINECONE_API_KEY, + PINECONE_ENVIRONMENT, + PINECONE_INDEX_NAME, + PINECONE_DIMENSION, + PINECONE_METRIC, + PINECONE_CLOUD, +) +from open_webui.env import SRC_LOG_LEVELS + +NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system +BATCH_SIZE = 100 # Recommended batch size for Pinecone operations + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +class PineconeClient(VectorDBBase): + def __init__(self): + self.collection_prefix = "open-webui" + + # Validate required configuration + self._validate_config() + + # Store configuration values + self.api_key = PINECONE_API_KEY + self.environment = PINECONE_ENVIRONMENT + self.index_name = PINECONE_INDEX_NAME + self.dimension = PINECONE_DIMENSION + self.metric = PINECONE_METRIC + self.cloud = PINECONE_CLOUD + + # Initialize Pinecone client + self.client = Pinecone(api_key=self.api_key) + + # Create index if it doesn't exist + self._initialize_index() + + def _validate_config(self) -> None: + """Validate that all required configuration variables are set.""" + missing_vars = [] + if not PINECONE_API_KEY: + missing_vars.append("PINECONE_API_KEY") + if not PINECONE_ENVIRONMENT: + missing_vars.append("PINECONE_ENVIRONMENT") + if not PINECONE_INDEX_NAME: + missing_vars.append("PINECONE_INDEX_NAME") + if not PINECONE_DIMENSION: + missing_vars.append("PINECONE_DIMENSION") + if not PINECONE_CLOUD: + missing_vars.append("PINECONE_CLOUD") + + if missing_vars: + raise ValueError( + f"Required configuration missing: {', '.join(missing_vars)}" + ) + + def _initialize_index(self) -> None: + """Initialize the Pinecone index.""" + try: + # Check if index exists + if self.index_name not in self.client.list_indexes().names(): + log.info(f"Creating Pinecone index '{self.index_name}'...") + self.client.create_index( + name=self.index_name, + dimension=self.dimension, + metric=self.metric, + spec=ServerlessSpec(cloud=self.cloud, region=self.environment), + ) + log.info(f"Successfully created Pinecone index '{self.index_name}'") + else: + log.info(f"Using existing Pinecone index '{self.index_name}'") + + # Connect to the index + self.index = self.client.Index(self.index_name) + + except Exception as e: + log.error(f"Failed to initialize Pinecone index: {e}") + raise RuntimeError(f"Failed to initialize Pinecone index: {e}") + + def _create_points( + self, items: List[VectorItem], collection_name_with_prefix: str + ) -> List[Dict[str, Any]]: + """Convert VectorItem objects to Pinecone point format.""" + points = [] + for item in items: + # Start with any existing metadata or an empty dict + metadata = item.get("metadata", {}).copy() if item.get("metadata") else {} + + # Add text to metadata if available + if "text" in item: + metadata["text"] = item["text"] + + # Always add collection_name to metadata for filtering + metadata["collection_name"] = collection_name_with_prefix + + point = { + "id": item["id"], + "values": item["vector"], + "metadata": metadata, + } + points.append(point) + return points + + def _get_collection_name_with_prefix(self, collection_name: str) -> str: + """Get the collection name with prefix.""" + return f"{self.collection_prefix}_{collection_name}" + + def _normalize_distance(self, score: float) -> float: + """Normalize distance score based on the metric used.""" + if self.metric.lower() == "cosine": + # Cosine similarity ranges from -1 to 1, normalize to 0 to 1 + return (score + 1.0) / 2.0 + elif self.metric.lower() in 
["euclidean", "dotproduct"]: + # These are already suitable for ranking (smaller is better for Euclidean) + return score + else: + # For other metrics, use as is + return score + + def _result_to_get_result(self, matches: list) -> GetResult: + """Convert Pinecone matches to GetResult format.""" + ids = [] + documents = [] + metadatas = [] + + for match in matches: + metadata = match.get("metadata", {}) + ids.append(match["id"]) + documents.append(metadata.get("text", "")) + metadatas.append(metadata) + + return GetResult( + **{ + "ids": [ids], + "documents": [documents], + "metadatas": [metadatas], + } + ) + + def has_collection(self, collection_name: str) -> bool: + """Check if a collection exists by searching for at least one item.""" + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + + try: + # Search for at least 1 item with this collection name in metadata + response = self.index.query( + vector=[0.0] * self.dimension, # dummy vector + top_k=1, + filter={"collection_name": collection_name_with_prefix}, + include_metadata=False, + ) + return len(response.matches) > 0 + except Exception as e: + log.exception( + f"Error checking collection '{collection_name_with_prefix}': {e}" + ) + return False + + def delete_collection(self, collection_name: str) -> None: + """Delete a collection by removing all vectors with the collection name in metadata.""" + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + try: + self.index.delete(filter={"collection_name": collection_name_with_prefix}) + log.info( + f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)." + ) + except Exception as e: + log.warning( + f"Failed to delete collection '{collection_name_with_prefix}': {e}" + ) + raise + + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + """Insert vectors into a collection.""" + if not items: + log.warning("No items to insert") + return + + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + points = self._create_points(items, collection_name_with_prefix) + + # Insert in batches for better performance and reliability + for i in range(0, len(points), BATCH_SIZE): + batch = points[i : i + BATCH_SIZE] + try: + self.index.upsert(vectors=batch) + log.debug( + f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'" + ) + except Exception as e: + log.error( + f"Error inserting batch into '{collection_name_with_prefix}': {e}" + ) + raise + + log.info( + f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'" + ) + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + """Upsert (insert or update) vectors into a collection.""" + if not items: + log.warning("No items to upsert") + return + + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + points = self._create_points(items, collection_name_with_prefix) + + # Upsert in batches + for i in range(0, len(points), BATCH_SIZE): + batch = points[i : i + BATCH_SIZE] + try: + self.index.upsert(vectors=batch) + log.debug( + f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'" + ) + except Exception as e: + log.error( + f"Error upserting batch into '{collection_name_with_prefix}': {e}" + ) + raise + + log.info( + f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'" + ) + + def search( + self, collection_name: str, vectors: 
List[List[Union[float, int]]], limit: int + ) -> Optional[SearchResult]: + """Search for similar vectors in a collection.""" + if not vectors or not vectors[0]: + log.warning("No vectors provided for search") + return None + + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + + if limit is None or limit <= 0: + limit = NO_LIMIT + + try: + # Search using the first vector (assuming this is the intended behavior) + query_vector = vectors[0] + + # Perform the search + query_response = self.index.query( + vector=query_vector, + top_k=limit, + include_metadata=True, + filter={"collection_name": collection_name_with_prefix}, + ) + + if not query_response.matches: + # Return empty result if no matches + return SearchResult( + ids=[[]], + documents=[[]], + metadatas=[[]], + distances=[[]], + ) + + # Convert to GetResult format + get_result = self._result_to_get_result(query_response.matches) + + # Calculate normalized distances based on metric + distances = [ + [ + self._normalize_distance(match.score) + for match in query_response.matches + ] + ] + + return SearchResult( + ids=get_result.ids, + documents=get_result.documents, + metadatas=get_result.metadatas, + distances=distances, + ) + except Exception as e: + log.error(f"Error searching in '{collection_name_with_prefix}': {e}") + return None + + def query( + self, collection_name: str, filter: Dict, limit: Optional[int] = None + ) -> Optional[GetResult]: + """Query vectors by metadata filter.""" + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + + if limit is None or limit <= 0: + limit = NO_LIMIT + + try: + # Create a zero vector for the dimension as Pinecone requires a vector + zero_vector = [0.0] * self.dimension + + # Combine user filter with collection_name + pinecone_filter = {"collection_name": collection_name_with_prefix} + if filter: + pinecone_filter.update(filter) + + # Perform metadata-only query + query_response = self.index.query( + vector=zero_vector, + filter=pinecone_filter, + top_k=limit, + include_metadata=True, + ) + + return self._result_to_get_result(query_response.matches) + + except Exception as e: + log.error(f"Error querying collection '{collection_name}': {e}") + return None + + def get(self, collection_name: str) -> Optional[GetResult]: + """Get all vectors in a collection.""" + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + + try: + # Use a zero vector for fetching all entries + zero_vector = [0.0] * self.dimension + + # Add filter to only get vectors for this collection + query_response = self.index.query( + vector=zero_vector, + top_k=NO_LIMIT, + include_metadata=True, + filter={"collection_name": collection_name_with_prefix}, + ) + + return self._result_to_get_result(query_response.matches) + + except Exception as e: + log.error(f"Error getting collection '{collection_name}': {e}") + return None + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict] = None, + ) -> None: + """Delete vectors by IDs or filter.""" + collection_name_with_prefix = self._get_collection_name_with_prefix( + collection_name + ) + + try: + if ids: + # Delete by IDs (in batches for large deletions) + for i in range(0, len(ids), BATCH_SIZE): + batch_ids = ids[i : i + BATCH_SIZE] + # Note: When deleting by ID, we can't filter by collection_name + # This is a limitation of Pinecone - be careful with ID uniqueness + self.index.delete(ids=batch_ids) + log.debug( + 
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'" + ) + log.info( + f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'" + ) + + elif filter: + # Combine user filter with collection_name + pinecone_filter = {"collection_name": collection_name_with_prefix} + if filter: + pinecone_filter.update(filter) + # Delete by metadata filter + self.index.delete(filter=pinecone_filter) + log.info( + f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'" + ) + + else: + log.warning("No ids or filter provided for delete operation") + + except Exception as e: + log.error(f"Error deleting from collection '{collection_name}': {e}") + raise + + def reset(self) -> None: + """Reset the database by deleting all collections.""" + try: + self.index.delete(delete_all=True) + log.info("All vectors successfully deleted from the index.") + except Exception as e: + log.error(f"Failed to reset Pinecone index: {e}") + raise diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index be0df6c6ac..dfe2979076 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -1,12 +1,24 @@ from typing import Optional import logging +from urllib.parse import urlparse from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct from qdrant_client.models import models -from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult -from open_webui.config import QDRANT_URI, QDRANT_API_KEY +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) +from open_webui.config import ( + QDRANT_URI, + QDRANT_API_KEY, + QDRANT_ON_DISK, + QDRANT_GRPC_PORT, + QDRANT_PREFER_GRPC, +) from open_webui.env import SRC_LOG_LEVELS NO_LIMIT = 999999999 @@ -15,16 +27,34 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class QdrantClient: +class QdrantClient(VectorDBBase): def __init__(self): self.collection_prefix = "open-webui" self.QDRANT_URI = QDRANT_URI self.QDRANT_API_KEY = QDRANT_API_KEY - self.client = ( - Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) - if self.QDRANT_URI - else None - ) + self.QDRANT_ON_DISK = QDRANT_ON_DISK + self.PREFER_GRPC = QDRANT_PREFER_GRPC + self.GRPC_PORT = QDRANT_GRPC_PORT + + if not self.QDRANT_URI: + self.client = None + return + + # Unified handling for either scheme + parsed = urlparse(self.QDRANT_URI) + host = parsed.hostname or self.QDRANT_URI + http_port = parsed.port or 6333 # default REST port + + if self.PREFER_GRPC: + self.client = Qclient( + host=host, + port=http_port, + grpc_port=self.GRPC_PORT, + prefer_grpc=self.PREFER_GRPC, + api_key=self.QDRANT_API_KEY, + ) + else: + self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) def _result_to_get_result(self, points) -> GetResult: ids = [] @@ -50,7 +80,9 @@ class QdrantClient: self.client.create_collection( collection_name=collection_name_with_prefix, vectors_config=models.VectorParams( - size=dimension, distance=models.Distance.COSINE + size=dimension, + distance=models.Distance.COSINE, + on_disk=self.QDRANT_ON_DISK, ), ) diff --git a/backend/open_webui/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py index f0cf0c0387..53f752f579 100644 --- a/backend/open_webui/retrieval/vector/main.py +++ b/backend/open_webui/retrieval/vector/main.py @@ -1,5 +1,6 @@ from pydantic 
import BaseModel -from typing import Optional, List, Any +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Union class VectorItem(BaseModel): @@ -17,3 +18,69 @@ class GetResult(BaseModel): class SearchResult(GetResult): distances: Optional[List[List[float | int]]] + + +class VectorDBBase(ABC): + """ + Abstract base class for all vector database backends. + + Implementations of this class provide methods for collection management, + vector insertion, deletion, similarity search, and metadata filtering. + + Any custom vector database integration must inherit from this class and + implement all abstract methods. + """ + + @abstractmethod + def has_collection(self, collection_name: str) -> bool: + """Check if the collection exists in the vector DB.""" + pass + + @abstractmethod + def delete_collection(self, collection_name: str) -> None: + """Delete a collection from the vector DB.""" + pass + + @abstractmethod + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + """Insert a list of vector items into a collection.""" + pass + + @abstractmethod + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + """Insert or update vector items in a collection.""" + pass + + @abstractmethod + def search( + self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + ) -> Optional[SearchResult]: + """Search for similar vectors in a collection.""" + pass + + @abstractmethod + def query( + self, collection_name: str, filter: Dict, limit: Optional[int] = None + ) -> Optional[GetResult]: + """Query vectors from a collection using metadata filter.""" + pass + + @abstractmethod + def get(self, collection_name: str) -> Optional[GetResult]: + """Retrieve all vectors from a collection.""" + pass + + @abstractmethod + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict] = None, + ) -> None: + """Delete vectors by ID or filter from a collection.""" + pass + + @abstractmethod + def reset(self) -> None: + """Reset the vector database by removing all collections or those matching a condition.""" + pass diff --git a/backend/open_webui/retrieval/web/external.py b/backend/open_webui/retrieval/web/external.py new file mode 100644 index 0000000000..a5c8003e47 --- /dev/null +++ b/backend/open_webui/retrieval/web/external.py @@ -0,0 +1,47 @@ +import logging +from typing import Optional, List + +import requests +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_external( + external_url: str, + external_api_key: str, + query: str, + count: int, + filter_list: Optional[List[str]] = None, +) -> List[SearchResult]: + try: + response = requests.post( + external_url, + headers={ + "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", + "Authorization": f"Bearer {external_api_key}", + }, + json={ + "query": query, + "count": count, + }, + ) + response.raise_for_status() + results = response.json() + if filter_list: + results = get_filtered_results(results, filter_list) + results = [ + SearchResult( + link=result.get("link"), + title=result.get("title"), + snippet=result.get("snippet"), + ) + for result in results[:count] + ] + log.info(f"External search results: {results}") + return results + except Exception as e: + log.error(f"Error in External search: {e}") + return [] diff --git 
a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index da70aa8e7f..bfd102afa6 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -21,18 +21,25 @@ def search_tavily( Args: api_key (str): A Tavily Search API key query (str): The query to search for + count (int): The maximum number of results to return Returns: list[SearchResult]: A list of search results """ url = "https://api.tavily.com/search" - data = {"query": query, "api_key": api_key} - response = requests.post(url, json=data) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + data = {"query": query, "max_results": count} + response = requests.post(url, headers=headers, json=data) response.raise_for_status() json_response = response.json() - raw_search_results = json_response.get("results", []) + results = json_response.get("results", []) + if filter_list: + results = get_filtered_results(results, filter_list) return [ SearchResult( @@ -40,5 +47,5 @@ def search_tavily( title=result.get("title", ""), snippet=result.get("content"), ) - for result in raw_search_results[:count] + for result in results ] diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 718cfe52fa..aec2a8730e 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -25,6 +25,7 @@ from langchain_community.document_loaders.firecrawl import FireCrawlLoader from langchain_community.document_loaders.base import BaseLoader from langchain_core.documents import Document from open_webui.retrieval.loaders.tavily import TavilyLoader +from open_webui.retrieval.loaders.external import ExternalLoader from open_webui.constants import ERROR_MESSAGES from open_webui.config import ( ENABLE_RAG_LOCAL_WEB_FETCH, @@ -35,6 +36,8 @@ from open_webui.config import ( FIRECRAWL_API_KEY, TAVILY_API_KEY, TAVILY_EXTRACT_DEPTH, + EXTERNAL_WEB_LOADER_URL, + EXTERNAL_WEB_LOADER_API_KEY, ) from open_webui.env import SRC_LOG_LEVELS @@ -225,7 +228,10 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): mode=self.mode, params=self.params, ) - yield from loader.lazy_load() + for document in loader.lazy_load(): + if not document.metadata.get("source"): + document.metadata["source"] = document.metadata.get("sourceURL") + yield document except Exception as e: if self.continue_on_failure: log.exception(f"Error loading {url}: {e}") @@ -245,6 +251,8 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin): params=self.params, ) async for document in loader.alazy_load(): + if not document.metadata.get("source"): + document.metadata["source"] = document.metadata.get("sourceURL") yield document except Exception as e: if self.continue_on_failure: @@ -619,6 +627,11 @@ def get_web_loader( web_loader_args["api_key"] = TAVILY_API_KEY.value web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value + if WEB_LOADER_ENGINE.value == "external": + WebLoaderClass = ExternalLoader + web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value + web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value + if WebLoaderClass: web_loader = 
WebLoaderClass(**web_loader_args) diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 9c4d5cb9f6..b1ac1ea13a 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -27,20 +27,24 @@ from open_webui.env import ( WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, + WEBUI_AUTH_SIGNOUT_REDIRECT_URL, SRC_LOG_LEVELS, ) from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse, Response from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP from pydantic import BaseModel + from open_webui.utils.misc import parse_duration, validate_email_format from open_webui.utils.auth import ( + decode_token, create_api_key, create_token, get_admin_user, get_verified_user, get_current_user, get_password_hash, + get_http_authorization_cred, ) from open_webui.utils.webhook import post_webhook from open_webui.utils.access_control import get_permissions @@ -72,31 +76,13 @@ class SessionUserResponse(Token, UserResponse): async def get_session_user( request: Request, response: Response, user=Depends(get_current_user) ): - expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) - expires_at = None - if expires_delta: - expires_at = int(time.time()) + int(expires_delta.total_seconds()) - token = create_token( - data={"id": user.id}, - expires_delta=expires_delta, - ) + auth_header = request.headers.get("Authorization") + auth_token = get_http_authorization_cred(auth_header) + token = auth_token.credentials - datetime_expires_at = ( - datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) - if expires_at - else None - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=token, - expires=datetime_expires_at, - httponly=True, # Ensures the cookie is not accessible via JavaScript - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) + data = decode_token(token) + expires_at = data.get("exp") user_permissions = get_permissions( user.id, request.app.state.config.USER_PERMISSIONS @@ -288,11 +274,14 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): user = Auths.authenticate_user_by_trusted_header(email) if user: + expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) + expires_at = None + if expires_delta: + expires_at = int(time.time()) + int(expires_delta.total_seconds()) + token = create_token( data={"id": user.id}, - expires_delta=parse_duration( - request.app.state.config.JWT_EXPIRES_IN - ), + expires_delta=expires_delta, ) # Set the cookie token @@ -300,6 +289,8 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): key="token", value=token, httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_AUTH_COOKIE_SAME_SITE, + secure=WEBUI_AUTH_COOKIE_SECURE, ) user_permissions = get_permissions( @@ -309,6 +300,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): return { "token": token, "token_type": "Bearer", + "expires_at": expires_at, "id": user.id, "email": user.email, "name": user.name, @@ -566,6 +558,12 @@ async def signout(request: Request, response: Response): detail="Failed to sign out from the OpenID provider.", ) + if WEBUI_AUTH_SIGNOUT_REDIRECT_URL: + return RedirectResponse( + headers=response.headers, + url=WEBUI_AUTH_SIGNOUT_REDIRECT_URL, + ) + return {"status": True} diff --git 
a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 5fd44ab9f0..6f00dd4d7c 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -638,8 +638,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/share", response_model=Optional[ChatResponse]) -async def share_chat_by_id(id: str, user=Depends(get_verified_user)): +async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): + if not has_permission( + user.id, "chat.share", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: if chat.share_id: shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 8a2888d864..5907b69f4e 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -173,7 +173,8 @@ async def list_files(user=Depends(get_verified_user), content: bool = Query(True if not content: for file in files: - del file.data["content"] + if "content" in file.data: + del file.data["content"] return files @@ -214,7 +215,8 @@ async def search_files( if not content: for file in matching_files: - del file.data["content"] + if "content" in file.data: + del file.data["content"] return matching_files diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 15547afa7c..d0d95e2f4e 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -9,7 +9,7 @@ from open_webui.models.knowledge import ( KnowledgeResponse, KnowledgeUserResponse, ) -from open_webui.models.files import Files, FileModel +from open_webui.models.files import Files, FileModel, FileMetadataResponse from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.routers.retrieval import ( process_file, @@ -235,7 +235,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us class KnowledgeFilesResponse(KnowledgeResponse): - files: list[FileModel] + files: list[FileMetadataResponse] @router.get("/{id}", response_model=Optional[KnowledgeFilesResponse]) @@ -251,7 +251,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): ): file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] - files = Files.get_files_by_ids(file_ids) + files = Files.get_file_metadatas_by_ids(file_ids) return KnowledgeFilesResponse( **knowledge.model_dump(), @@ -379,7 +379,7 @@ def add_file_to_knowledge_by_id( knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) if knowledge: - files = Files.get_files_by_ids(file_ids) + files = Files.get_file_metadatas_by_ids(file_ids) return KnowledgeFilesResponse( **knowledge.model_dump(), @@ -456,7 +456,7 @@ def update_file_from_knowledge_by_id( data = knowledge.data or {} file_ids = data.get("file_ids", []) - files = Files.get_files_by_ids(file_ids) + files = Files.get_file_metadatas_by_ids(file_ids) return KnowledgeFilesResponse( **knowledge.model_dump(), @@ -538,7 +538,7 @@ def remove_file_from_knowledge_by_id( knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) if knowledge: - files = Files.get_files_by_ids(file_ids) + files = Files.get_file_metadatas_by_ids(file_ids) return KnowledgeFilesResponse( **knowledge.model_dump(), @@ -734,7 +734,7 
@@ def add_files_to_knowledge_batch( error_details = [f"{err.file_id}: {err.error}" for err in result.errors] return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Files.get_files_by_ids(existing_file_ids), + files=Files.get_file_metadatas_by_ids(existing_file_ids), warnings={ "message": "Some files failed to process", "errors": error_details, @@ -742,5 +742,6 @@ def add_files_to_knowledge_batch( ) return KnowledgeFilesResponse( - **knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids) + **knowledge.model_dump(), + files=Files.get_file_metadatas_by_ids(existing_file_ids), ) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 775cd04465..b41a240a70 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -54,6 +54,7 @@ from open_webui.config import ( from open_webui.env import ( ENV, SRC_LOG_LEVELS, + AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, BYPASS_MODEL_ACCESS_CONTROL, @@ -91,6 +92,7 @@ async def send_get_request(url, key=None, user: UserModel = None): else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: return await response.json() except Exception as e: @@ -141,6 +143,7 @@ async def send_post_request( else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) r.raise_for_status() @@ -216,7 +219,8 @@ async def verify_connection( key = form_data.key async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) + trust_env=True, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: async with session.get( @@ -234,6 +238,7 @@ async def verify_connection( else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: detail = f"HTTP Error: {r.status}" @@ -1006,7 +1011,7 @@ class GenerateCompletionForm(BaseModel): prompt: str suffix: Optional[str] = None images: Optional[list[str]] = None - format: Optional[str] = None + format: Optional[Union[dict, str]] = None options: Optional[dict] = None system: Optional[str] = None template: Optional[str] = None @@ -1482,7 +1487,9 @@ async def download_file_stream( timeout = aiohttp.ClientTimeout(total=600) # Set the timeout async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(file_url, headers=headers) as response: + async with session.get( + file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as response: total_size = int(response.headers.get("content-length", 0)) + current_size with open(file_path, "ab+") as file: @@ -1497,7 +1504,8 @@ async def download_file_stream( if done: file.seek(0) - hashed = calculate_sha256(file) + chunk_size = 1024 * 1024 * 2 + hashed = calculate_sha256(file, chunk_size) file.seek(0) url = f"{ollama_url}/api/blobs/sha256:{hashed}" diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 0310014cf5..02a81209c1 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -21,6 +21,7 @@ from open_webui.config import ( CACHE_DIR, ) from open_webui.env import ( + AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ENABLE_FORWARD_USER_INFO_HEADERS, @@ -74,6 +75,7 @@ async def send_get_request(url, key=None, user: UserModel = None): else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as response: return await response.json() except Exception as e: @@ -92,20 +94,19 @@ async def cleanup_response( await session.close() 
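The hunk below generalizes the former o1/o3 handler to the whole OpenAI "o" series. As a self-contained sketch of the transformation it performs (the function name here is illustrative):

```python
def o_series_payload_sketch(payload: dict) -> dict:
    """Illustrative restatement of what openai_o_series_handler does."""
    # o-series models accept max_completion_tokens, not max_tokens.
    if "max_tokens" in payload:
        payload["max_completion_tokens"] = payload.pop("max_tokens")

    # A leading "system" message becomes "user" on the legacy o1-mini /
    # o1-preview models and "developer" on newer o-series models.
    if payload["messages"][0]["role"] == "system":
        model = payload["model"].lower()
        if model.startswith(("o1-mini", "o1-preview")):
            payload["messages"][0]["role"] = "user"
        else:
            payload["messages"][0]["role"] = "developer"
    return payload


# The caller applies this only when the model name starts with o1/o3/o4:
print(o_series_payload_sketch({
    "model": "o3-mini",
    "max_tokens": 256,
    "messages": [{"role": "system", "content": "Be concise."}],
}))
```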
-def openai_o1_o3_handler(payload): +def openai_o_series_handler(payload): """ - Handle o1, o3 specific parameters + Handle "o" series specific parameters """ if "max_tokens" in payload: - # Remove "max_tokens" from the payload + # Convert "max_tokens" to "max_completion_tokens" for all o-series models payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] - # Fix: o1 and o3 do not support the "system" role directly. - # For older models like "o1-mini" or "o1-preview", use role "user". - # For newer o1/o3 models, replace "system" with "developer". + # Handle system role conversion based on model type if payload["messages"][0]["role"] == "system": model_lower = payload["model"].lower() + # Legacy models use "user" role instead of "system" if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"): payload["messages"][0]["role"] = "user" else: @@ -462,7 +463,8 @@ async def get_models( r = None async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) + trust_env=True, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: async with session.get( @@ -481,6 +483,7 @@ async def get_models( else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: # Extract response error details if available @@ -542,7 +545,8 @@ async def verify_connection( key = form_data.key async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) + trust_env=True, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: async with session.get( @@ -561,6 +565,7 @@ async def verify_connection( else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: # Extract response error details if available @@ -666,10 +671,10 @@ async def generate_chat_completion( url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] - # Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" - is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-")) - if is_o1_o3: - payload = openai_o1_o3_handler(payload) + # Check if model is from "o" series + is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4")) + if is_o_series: + payload = openai_o_series_handler(payload) elif "api.openai.com" not in url: # Remove "max_completion_tokens" from the payload for backward compatibility if "max_completion_tokens" in payload: @@ -723,6 +728,7 @@ async def generate_chat_completion( else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) # Check if response is SSE @@ -802,6 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): else {} ), }, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) r.raise_for_status() diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 10c8e9b2ec..f140025026 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -66,7 +66,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models): if "pipeline" in model: sorted_filters.append(model) - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: for filter in sorted_filters: urlIdx = filter.get("urlIdx") if urlIdx is None: @@ -115,7 +115,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models): if "pipeline" in model: sorted_filters = 
[model] + sorted_filters - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: for filter in sorted_filters: urlIdx = filter.get("urlIdx") if urlIdx is None: diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 13f012483d..a582bd9bad 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -61,6 +61,7 @@ from open_webui.retrieval.web.bing import search_bing from open_webui.retrieval.web.exa import search_exa from open_webui.retrieval.web.perplexity import search_perplexity from open_webui.retrieval.web.sougou import search_sougou +from open_webui.retrieval.web.external import search_external from open_webui.retrieval.utils import ( get_embedding_function, @@ -90,7 +91,12 @@ from open_webui.env import ( SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, + SENTENCE_TRANSFORMERS_BACKEND, + SENTENCE_TRANSFORMERS_MODEL_KWARGS, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, ) + from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) @@ -117,6 +123,8 @@ def get_ef( get_model_path(embedding_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + backend=SENTENCE_TRANSFORMERS_BACKEND, + model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS, ) except Exception as e: log.debug(f"Error loading SentenceTransformer: {e}") @@ -150,6 +158,8 @@ def get_rf( get_model_path(reranking_model, auto_update), device=DEVICE_TYPE, trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, + model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, ) except Exception as e: log.error(f"CrossEncoder: {e}") @@ -418,6 +428,10 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, + "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, + "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, @@ -463,6 +477,10 @@ class WebConfig(BaseModel): FIRECRAWL_API_KEY: Optional[str] = None FIRECRAWL_API_BASE_URL: Optional[str] = None TAVILY_EXTRACT_DEPTH: Optional[str] = None + EXTERNAL_WEB_SEARCH_URL: Optional[str] = None + EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None + EXTERNAL_WEB_LOADER_URL: Optional[str] = None + EXTERNAL_WEB_LOADER_API_KEY: Optional[str] = None YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None YOUTUBE_LOADER_PROXY_URL: Optional[str] = None YOUTUBE_LOADER_TRANSLATION: Optional[str] = None @@ -697,6 +715,18 @@ async def update_rag_config( request.app.state.config.FIRECRAWL_API_BASE_URL = ( form_data.web.FIRECRAWL_API_BASE_URL ) + request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( + form_data.web.EXTERNAL_WEB_SEARCH_URL + ) + request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = ( + form_data.web.EXTERNAL_WEB_SEARCH_API_KEY + ) 
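The renamed `openai_o_series_handler` in the openai.py hunks earlier in this diff normalizes payloads for OpenAI's reasoning models. Restated as a self-contained sketch, with behavior taken directly from the hunk (the `o4` prefix comes from the updated `startswith` tuple):

```python
def openai_o_series_handler(payload: dict) -> dict:
    # o-series models accept max_completion_tokens instead of max_tokens.
    if "max_tokens" in payload:
        payload["max_completion_tokens"] = payload.pop("max_tokens")

    # Legacy o1-mini / o1-preview reject the "system" role outright and
    # expect "user"; newer o-series models expect "developer" instead.
    if payload.get("messages") and payload["messages"][0]["role"] == "system":
        model = payload["model"].lower()
        if model.startswith(("o1-mini", "o1-preview")):
            payload["messages"][0]["role"] = "user"
        else:
            payload["messages"][0]["role"] = "developer"
    return payload

if __name__ == "__main__":
    # Dispatch mirrors the router: any o1/o3/o4 model goes through the handler.
    payload = {
        "model": "o3-mini",
        "max_tokens": 512,
        "messages": [{"role": "system", "content": "Be terse."}],
    }
    if payload["model"].lower().startswith(("o1", "o3", "o4")):
        payload = openai_o_series_handler(payload)
    print(payload)  # max_completion_tokens=512, first role becomes "developer"
```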
+ request.app.state.config.EXTERNAL_WEB_LOADER_URL = ( + form_data.web.EXTERNAL_WEB_LOADER_URL + ) + request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = ( + form_data.web.EXTERNAL_WEB_LOADER_API_KEY + ) request.app.state.config.TAVILY_EXTRACT_DEPTH = ( form_data.web.TAVILY_EXTRACT_DEPTH ) @@ -778,6 +808,10 @@ async def update_rag_config( "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, + "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL, + "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY, "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, @@ -1465,6 +1499,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: raise Exception( "No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables" ) + elif engine == "external": + return search_external( + request.app.state.config.EXTERNAL_WEB_SEARCH_URL, + request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, + query, + request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + ) else: raise Exception("No search engine API key found in environment variables") @@ -1477,8 +1519,11 @@ async def process_web_search( logging.info( f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}" ) - web_results = search_web( - request, request.app.state.config.WEB_SEARCH_ENGINE, form_data.query + web_results = await run_in_threadpool( + search_web, + request, + request.app.state.config.WEB_SEARCH_ENGINE, + form_data.query, ) except Exception as e: log.exception(e) @@ -1500,8 +1545,8 @@ async def process_web_search( ) docs = await loader.aload() urls = [ - doc.metadata["source"] for doc in docs - ] # only keep URLs which could be retrieved + doc.metadata.get("source") for doc in docs if doc.metadata.get("source") + ] # only keep URLs if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: return { @@ -1521,19 +1566,22 @@ async def process_web_search( collection_names = [] for doc_idx, doc in enumerate(docs): if doc and doc.page_content: - collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[ - :63 - ] + try: + collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[ + :63 + ] - collection_names.append(collection_name) - await run_in_threadpool( - save_docs_to_vector_db, - request, - [doc], - collection_name, - overwrite=True, - user=user, - ) + collection_names.append(collection_name) + await run_in_threadpool( + save_docs_to_vector_db, + request, + [doc], + collection_name, + overwrite=True, + user=user, + ) + except Exception as e: + log.debug(f"error saving doc {doc_idx}: {e}") return { "status": True, diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index a9ac34e2fb..29638199e0 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -88,6 +88,8 @@ class ChatPermissions(BaseModel): file_upload: bool = True 
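`search_web` fans out to synchronous engine clients, so the process_web_search hunk above moves it off the event loop with `run_in_threadpool` and wraps each per-document vector-DB save in its own try/except, so one bad page no longer aborts the whole batch. A condensed sketch of that control flow (function names follow the hunk; the bodies are stubs):

```python
from starlette.concurrency import run_in_threadpool

def search_web(engine: str, query: str) -> list[dict]:
    # Stub: the real function calls a blocking, requests-based client.
    return [{"url": f"https://example.com/?q={query}", "engine": engine}]

def save_docs_to_vector_db(doc: dict, collection_name: str) -> None:
    # Stub standing in for the embedding + upsert step.
    pass

async def process_web_search(engine: str, query: str) -> list[str]:
    # The blocking search runs in the threadpool so the event loop stays free.
    results = await run_in_threadpool(search_web, engine, query)
    collection_names = []
    for idx, doc in enumerate(results):
        try:
            # Vector collection names are capped at 63 characters.
            name = f"web-search-{idx}"[:63]
            await run_in_threadpool(save_docs_to_vector_db, doc, name)
            collection_names.append(name)
        except Exception as e:
            print(f"error saving doc {idx}: {e}")  # skip bad docs, keep the batch
    return collection_names
```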
delete: bool = True edit: bool = True + share: bool = True + export: bool = True stt: bool = True tts: bool = True call: bool = True @@ -288,6 +290,21 @@ async def update_user_by_id( form_data: UserUpdateForm, session_user=Depends(get_admin_user), ): + # Prevent modification of the primary admin user by other admins + try: + first_user = Users.get_first_user() + if first_user and user_id == first_user.id and session_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + except Exception as e: + log.error(f"Error checking primary admin status: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Could not verify primary admin status.", + ) + user = Users.get_user_by_id(user_id) if user: @@ -335,6 +352,21 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): + # Prevent deletion of the primary admin user + try: + first_user = Users.get_first_user() + if first_user and user_id == first_user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + except Exception as e: + log.error(f"Error checking primary admin status: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Could not verify primary admin status.", + ) + if user.id != user_id: result = Auths.delete_auth_by_id(user_id) @@ -346,6 +378,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): detail=ERROR_MESSAGES.DELETE_USER_ERROR, ) + # Prevent self-deletion raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACTION_PROHIBITED, diff --git a/backend/open_webui/utils/audit.py b/backend/open_webui/utils/audit.py index 2d7ceabcb8..8193907d27 100644 --- a/backend/open_webui/utils/audit.py +++ b/backend/open_webui/utils/audit.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: class AuditLogEntry: # `Metadata` audit level properties id: str - user: dict[str, Any] + user: Optional[dict[str, Any]] audit_level: str verb: str request_uri: str @@ -190,21 +190,40 @@ class AuditLoggingMiddleware: finally: await self._log_audit_entry(request, context) - async def _get_authenticated_user(self, request: Request) -> UserModel: - + async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]: auth_header = request.headers.get("Authorization") - assert auth_header - user = get_current_user(request, None, get_http_authorization_cred(auth_header)) - return user + try: + user = get_current_user( + request, None, get_http_authorization_cred(auth_header) + ) + return user + except Exception as e: + logger.debug(f"Failed to get authenticated user: {str(e)}") + + return None def _should_skip_auditing(self, request: Request) -> bool: if ( request.method not in {"POST", "PUT", "PATCH", "DELETE"} or AUDIT_LOG_LEVEL == "NONE" - or not request.headers.get("authorization") ): return True + + ALWAYS_LOG_ENDPOINTS = { + "/api/v1/auths/signin", + "/api/v1/auths/signout", + "/api/v1/auths/signup", + } + path = request.url.path.lower() + for endpoint in ALWAYS_LOG_ENDPOINTS: + if path.startswith(endpoint): + return False # Do NOT skip logging for auth endpoints + + # Skip logging if the request is not authenticated + if not request.headers.get("authorization"): + return True + # match either /api//...(for the endpoint /api/chat case) or /api/v1//... 
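The users.py hunks above stop other admins from editing, and anyone from deleting, the first registered account. The guard distilled into one helper (assuming, as the hunk implies, that `Users.get_first_user()` returns the earliest-created user; note the delete path drops the self-exemption entirely):

```python
from fastapi import HTTPException, status

def ensure_not_primary_admin(
    target_id: str, acting_id: str, first_user, allow_self: bool = True
) -> None:
    # The first registered account is treated as the primary admin.
    # Updates: only the primary admin may edit their own account (allow_self=True).
    # Deletes: the primary admin can never be removed, even by themselves.
    if first_user and target_id == first_user.id:
        if not allow_self or acting_id != target_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Action prohibited",
            )
```

Usage mirrors the two endpoints: `ensure_not_primary_admin(user_id, session_user.id, first_user)` on update, and `allow_self=False` on delete.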
pattern = re.compile( r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b" @@ -231,17 +250,32 @@ class AuditLoggingMiddleware: try: user = await self._get_authenticated_user(request) + user = ( + user.model_dump(include={"id", "name", "email", "role"}) if user else {} + ) + + request_body = context.request_body.decode("utf-8", errors="replace") + response_body = context.response_body.decode("utf-8", errors="replace") + + # Redact sensitive information + if "password" in request_body: + request_body = re.sub( + r'"password":\s*"(.*?)"', + '"password": "********"', + request_body, + ) + entry = AuditLogEntry( id=str(uuid.uuid4()), - user=user.model_dump(include={"id", "name", "email", "role"}), + user=user, audit_level=self.audit_level.value, verb=request.method, request_uri=str(request.url), response_status_code=context.metadata.get("response_status_code", None), source_ip=request.client.host if request.client else None, user_agent=request.headers.get("user-agent"), - request_object=context.request_body.decode("utf-8", errors="replace"), - response_object=context.response_body.decode("utf-8", errors="replace"), + request_object=request_body, + response_object=response_body, ) self.audit_logger.write(entry) diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py index 312baff241..1ad5ee93c9 100644 --- a/backend/open_webui/utils/code_interpreter.py +++ b/backend/open_webui/utils/code_interpreter.py @@ -50,7 +50,7 @@ class JupyterCodeExecuter: self.password = password self.timeout = timeout self.kernel_id = "" - self.session = aiohttp.ClientSession(base_url=self.base_url) + self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url) self.params = {} self.result = ResultModel() diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 4070bc697f..df771764c5 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -888,16 +888,20 @@ async def process_chat_payload(request, form_data, user, metadata, model): # If context is not empty, insert it into the messages if len(sources) > 0: context_string = "" - citated_file_idx = {} - for _, source in enumerate(sources, 1): + citation_idx = {} + for source in sources: if "document" in source: for doc_context, doc_meta in zip( source["document"], source["metadata"] ): - file_id = doc_meta.get("file_id") - if file_id not in citated_file_idx: - citated_file_idx[file_id] = len(citated_file_idx) + 1 - context_string += f'{doc_context}\n' + citation_id = ( + doc_meta.get("source", None) + or source.get("source", {}).get("id", None) + or "N/A" + ) + if citation_id not in citation_idx: + citation_idx[citation_id] = len(citation_idx) + 1 + context_string += f'{doc_context}\n' context_string = context_string.strip() prompt = get_last_user_message(form_data["messages"]) @@ -1667,6 +1671,15 @@ async def process_chat_response( if current_response_tool_call is None: # Add the new tool call + delta_tool_call.setdefault( + "function", {} + ) + delta_tool_call[ + "function" + ].setdefault("name", "") + delta_tool_call[ + "function" + ].setdefault("arguments", "") response_tool_calls.append( delta_tool_call ) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 9ebe0e6dcb..d526382c12 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -15,7 +15,7 @@ from starlette.responses import RedirectResponse from open_webui.models.auths import Auths from 
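The audit middleware above now masks credentials before persisting request bodies. The substitution is exactly the regex from the hunk; the quick check below assumes a JSON-style double-quoted field, just as the hunk's `if "password" in request_body` guard does:

```python
import re

def redact_password(body: str) -> str:
    # Mask any "password" field in the serialized request body before it
    # reaches the audit log.
    if "password" in body:
        body = re.sub(r'"password":\s*"(.*?)"', '"password": "********"', body)
    return body

assert (
    redact_password('{"email": "a@b.co", "password": "hunter2"}')
    == '{"email": "a@b.co", "password": "********"}'
)
```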
open_webui.models.users import Users -from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm +from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm from open_webui.config import ( DEFAULT_USER_ROLE, ENABLE_OAUTH_SIGNUP, @@ -23,6 +23,7 @@ from open_webui.config import ( OAUTH_PROVIDERS, ENABLE_OAUTH_ROLE_MANAGEMENT, ENABLE_OAUTH_GROUP_MANAGEMENT, + ENABLE_OAUTH_GROUP_CREATION, OAUTH_ROLES_CLAIM, OAUTH_GROUPS_CLAIM, OAUTH_EMAIL_CLAIM, @@ -57,6 +58,7 @@ auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT +auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM @@ -152,6 +154,51 @@ class OAuthManager: user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) all_available_groups: list[GroupModel] = Groups.get_groups() + # Create groups if they don't exist and creation is enabled + if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: + log.debug("Checking for missing groups to create...") + all_group_names = {g.name for g in all_available_groups} + groups_created = False + # Determine creator ID: Prefer admin, fallback to current user if no admin exists + admin_user = Users.get_admin_user() + creator_id = admin_user.id if admin_user else user.id + log.debug(f"Using creator ID {creator_id} for potential group creation.") + + for group_name in user_oauth_groups: + if group_name not in all_group_names: + log.info( + f"Group '{group_name}' not found via OAuth claim. Creating group..." + ) + try: + new_group_form = GroupForm( + name=group_name, + description=f"Group '{group_name}' created automatically via OAuth.", + permissions=default_permissions, # Use default permissions from function args + user_ids=[], # Start with no users, user will be added later by subsequent logic + ) + # Use determined creator ID (admin or fallback to current user) + created_group = Groups.insert_new_group( + creator_id, new_group_form + ) + if created_group: + log.info( + f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}" + ) + groups_created = True + # Add to local set to prevent duplicate creation attempts in this run + all_group_names.add(group_name) + else: + log.error( + f"Failed to create group '{group_name}' via OAuth." 
+ ) + except Exception as e: + log.error(f"Error creating group '{group_name}' via OAuth: {e}") + + # Refresh the list of all available groups if any were created + if groups_created: + all_available_groups = Groups.get_groups() + log.debug("Refreshed list of all available groups after creation.") + log.debug(f"Oauth Groups claim: {oauth_claim}") log.debug(f"User oauth groups: {user_oauth_groups}") log.debug(f"User's current groups: {[g.name for g in user_current_groups]}") @@ -257,7 +304,7 @@ class OAuthManager: try: access_token = token.get("access_token") headers = {"Authorization": f"Bearer {access_token}"} - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( "https://api.github.com/user/emails", headers=headers ) as resp: @@ -339,7 +386,7 @@ class OAuthManager: get_kwargs["headers"] = { "Authorization": f"Bearer {access_token}", } - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( picture_url, **get_kwargs ) as resp: diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index b5d916e1de..ad520ed09a 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -36,7 +36,10 @@ from langchain_core.utils.function_calling import ( from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.utils.plugin import load_tool_module_by_id -from open_webui.env import AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA +from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, + AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, +) import copy @@ -371,51 +374,64 @@ def convert_openapi_to_tool_payload(openapi_spec): for path, methods in openapi_spec.get("paths", {}).items(): for method, operation in methods.items(): - tool = { - "type": "function", - "name": operation.get("operationId"), - "description": operation.get( - "description", operation.get("summary", "No description available.") - ), - "parameters": {"type": "object", "properties": {}, "required": []}, - } - - # Extract path and query parameters - for param in operation.get("parameters", []): - param_name = param["name"] - param_schema = param.get("schema", {}) - tool["parameters"]["properties"][param_name] = { - "type": param_schema.get("type"), - "description": param_schema.get("description", ""), + if operation.get("operationId"): + tool = { + "type": "function", + "name": operation.get("operationId"), + "description": operation.get( + "description", + operation.get("summary", "No description available."), + ), + "parameters": {"type": "object", "properties": {}, "required": []}, } - if param.get("required"): - tool["parameters"]["required"].append(param_name) - # Extract and resolve requestBody if available - request_body = operation.get("requestBody") - if request_body: - content = request_body.get("content", {}) - json_schema = content.get("application/json", {}).get("schema") - if json_schema: - resolved_schema = resolve_schema( - json_schema, openapi_spec.get("components", {}) - ) - - if resolved_schema.get("properties"): - tool["parameters"]["properties"].update( - resolved_schema["properties"] + # Extract path and query parameters + for param in operation.get("parameters", []): + param_name = param["name"] + param_schema = param.get("schema", {}) + description = param_schema.get("description", "") + if not description: + description = param.get("description") or "" + if 
param_schema.get("enum") and isinstance( + param_schema.get("enum"), list + ): + description += ( + f". Possible values: {', '.join(param_schema.get('enum'))}" ) - if "required" in resolved_schema: - tool["parameters"]["required"] = list( - set( - tool["parameters"]["required"] - + resolved_schema["required"] - ) - ) - elif resolved_schema.get("type") == "array": - tool["parameters"] = resolved_schema # special case for array + tool["parameters"]["properties"][param_name] = { + "type": param_schema.get("type"), + "description": description, + } + if param.get("required"): + tool["parameters"]["required"].append(param_name) - tool_payload.append(tool) + # Extract and resolve requestBody if available + request_body = operation.get("requestBody") + if request_body: + content = request_body.get("content", {}) + json_schema = content.get("application/json", {}).get("schema") + if json_schema: + resolved_schema = resolve_schema( + json_schema, openapi_spec.get("components", {}) + ) + + if resolved_schema.get("properties"): + tool["parameters"]["properties"].update( + resolved_schema["properties"] + ) + if "required" in resolved_schema: + tool["parameters"]["required"] = list( + set( + tool["parameters"]["required"] + + resolved_schema["required"] + ) + ) + elif resolved_schema.get("type") == "array": + tool["parameters"] = ( + resolved_schema # special case for array + ) + + tool_payload.append(tool) return tool_payload @@ -431,8 +447,10 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: error = None try: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers) as response: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL + ) as response: if response.status != 200: error_body = await response.json() raise Exception(error_body) @@ -573,19 +591,26 @@ async def execute_tool_server( if token: headers["Authorization"] = f"Bearer {token}" - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: request_method = getattr(session, http_method.lower()) if http_method in ["post", "put", "patch"]: async with request_method( - final_url, json=body_params, headers=headers + final_url, + json=body_params, + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, ) as response: if response.status >= 400: text = await response.text() raise Exception(f"HTTP error {response.status}: {text}") return await response.json() else: - async with request_method(final_url, headers=headers) as response: + async with request_method( + final_url, + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, + ) as response: if response.status >= 400: text = await response.text() raise Exception(f"HTTP error {response.status}: {text}") diff --git a/backend/requirements.txt b/backend/requirements.txt index f0cf262ee6..5ba6a84d4e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -50,10 +50,10 @@ qdrant-client~=1.12.0 opensearch-py==2.8.0 playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml elasticsearch==8.17.1 - +pinecone==6.0.2 transformers -sentence-transformers==3.3.1 +sentence-transformers==4.1.0 accelerate colbert-ai==0.2.21 einops==0.8.1 diff --git a/package-lock.json b/package-lock.json index c4f6948bd4..7cabf09741 100644 --- 
a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.6.5", + "version": "0.6.6", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.6.5", + "version": "0.6.6", "dependencies": { "@azure/msal-browser": "^4.5.0", "@codemirror/lang-javascript": "^6.2.2", diff --git a/package.json b/package.json index 0a982196e4..8e3feab5a3 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.6.5", + "version": "0.6.6", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index 8a48c90fac..7760b51bc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,9 +58,10 @@ dependencies = [ "opensearch-py==2.8.0", "playwright==1.49.1", "elasticsearch==8.17.1", + "pinecone==6.0.2", "transformers", - "sentence-transformers==3.3.1", + "sentence-transformers==4.1.0", "accelerate", "colbert-ai==0.2.21", "einops==0.8.1", diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index d9771f8354..96b8874056 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -30,9 +30,10 @@ 'bing', 'exa', 'perplexity', - 'sougou' + 'sougou', + 'external' ]; - let webLoaderEngines = ['playwright', 'firecrawl', 'tavily']; + let webLoaderEngines = ['playwright', 'firecrawl', 'tavily', 'external']; let webConfig = null; @@ -431,6 +432,37 @@ /> + {:else if webConfig.WEB_SEARCH_ENGINE === 'external'} +
+
+
+ {$i18n.t('External Web Search URL')} +
+ +
+
+ +
+
+
+ +
+
+ {$i18n.t('External Web Search API Key')} +
+ + +
+
{/if} {/if} @@ -652,6 +684,37 @@ {/if} + {:else if webConfig.WEB_LOADER_ENGINE === 'external'} +
+
+
+ {$i18n.t('External Web Loader URL')} +
+ +
+
+ +
+
+
+ +
+
+ {$i18n.t('External Web Loader API Key')} +
+ + +
+
{/if}
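Both the backend (`search_external` in retrieval.py, earlier in this diff) and this admin panel treat "external" as a bring-your-own-endpoint engine: Open WebUI sends the query to `EXTERNAL_WEB_SEARCH_URL` authenticated with the configured API key. The wire format below is an assumption for illustration only; the actual request and response schema is defined by `open_webui.retrieval.web.external`, which is not shown in this diff:

```python
import requests

def search_external(url: str, api_key: str, query: str, count: int) -> list[dict]:
    # Hypothetical contract: POST the query, get back a list of
    # {link, title, snippet} objects. Adjust to your search service.
    resp = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},
        json={"query": query, "count": count},
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()
```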
diff --git a/src/lib/components/admin/Users/Groups.svelte b/src/lib/components/admin/Users/Groups.svelte index dce8423e5d..36c1014ba2 100644 --- a/src/lib/components/admin/Users/Groups.svelte +++ b/src/lib/components/admin/Users/Groups.svelte @@ -63,6 +63,8 @@ file_upload: true, delete: true, edit: true, + share: true, + export: true, stt: true, tts: true, call: true, diff --git a/src/lib/components/admin/Users/Groups/Permissions.svelte b/src/lib/components/admin/Users/Groups/Permissions.svelte index c7a1308a5b..9edf20ca0c 100644 --- a/src/lib/components/admin/Users/Groups/Permissions.svelte +++ b/src/lib/components/admin/Users/Groups/Permissions.svelte @@ -24,6 +24,8 @@ file_upload: true, delete: true, edit: true, + share: true, + export: true, stt: true, tts: true, call: true, @@ -276,6 +278,22 @@
+
+
+ {$i18n.t('Allow Chat Share')} +
+ + +
+ +
+
+ {$i18n.t('Allow Chat Export')} +
+ + +
+
{$i18n.t('Allow Speech to Text')} diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index ce2aa54f1c..a0de69570e 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -262,7 +262,7 @@ {#if threadId !== null} { + onClose={() => { threadId = null; }} > diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index fb9faa2470..b0f806452a 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -236,9 +236,11 @@ await tick(); await tick(); - const messageElement = document.getElementById(`message-${message.id}`); - if (messageElement) { - messageElement.scrollIntoView({ behavior: 'smooth' }); + if ($settings?.scrollOnBranchChange ?? true) { + const messageElement = document.getElementById(`message-${message.id}`); + if (messageElement) { + messageElement.scrollIntoView({ behavior: 'smooth' }); + } } await tick(); diff --git a/src/lib/components/chat/ChatControls.svelte b/src/lib/components/chat/ChatControls.svelte index 92c7e4d8d0..64fd8d92d3 100644 --- a/src/lib/components/chat/ChatControls.svelte +++ b/src/lib/components/chat/ChatControls.svelte @@ -140,7 +140,7 @@ {#if $showControls} { + onClose={() => { showControls.set(false); }} > diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index ca6487cf58..d31861459a 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -395,39 +395,37 @@
- {#if atSelectedModel !== undefined || selectedToolIds.length > 0 || webSearchEnabled || ($settings?.webSearch ?? false) === 'always' || imageGenerationEnabled || codeInterpreterEnabled} + {#if atSelectedModel !== undefined}
- {#if atSelectedModel !== undefined} -
-
- model profile model.id === atSelectedModel.id)?.info?.meta - ?.profile_image_url ?? - ($i18n.language === 'dg-DG' - ? `/doge.png` - : `${WEBUI_BASE_URL}/static/favicon.png`)} - /> -
- Talking to {atSelectedModel.name} -
-
-
- +
+
+ model profile model.id === atSelectedModel.id)?.info?.meta + ?.profile_image_url ?? + ($i18n.language === 'dg-DG' + ? `/doge.png` + : `${WEBUI_BASE_URL}/static/favicon.png`)} + /> +
+ Talking to {atSelectedModel.name}
- {/if} +
+ +
+
{/if} @@ -1063,9 +1061,9 @@ ); } }} - uploadOneDriveHandler={async () => { + uploadOneDriveHandler={async (authorityType) => { try { - const fileData = await pickAndDownloadFile(); + const fileData = await pickAndDownloadFile(authorityType); if (fileData) { const file = new File([fileData.blob], fileData.name, { type: fileData.blob.type || 'application/octet-stream' diff --git a/src/lib/components/chat/MessageInput/InputMenu.svelte b/src/lib/components/chat/MessageInput/InputMenu.svelte index 27fe2cde29..7f269ef3e8 100644 --- a/src/lib/components/chat/MessageInput/InputMenu.svelte +++ b/src/lib/components/chat/MessageInput/InputMenu.svelte @@ -229,94 +229,119 @@ {/if} {#if $config?.features?.enable_onedrive_integration} - { - uploadOneDriveHandler(); - }} - > - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
{$i18n.t('Microsoft OneDrive')}
+ + + { + uploadOneDriveHandler('personal'); + }} > - - - - - - - - - - - - - - - - - - - - - - - - - - - -
{$i18n.t('OneDrive')}
-
+
{$i18n.t('Microsoft OneDrive (personal)')}
+
+ { + uploadOneDriveHandler('organizations'); + }} + > +
+
{$i18n.t('Microsoft OneDrive (work/school)')}
+
Includes SharePoint
+
+
+ + {/if}
diff --git a/src/lib/components/chat/Messages/CitationsModal.svelte b/src/lib/components/chat/Messages/CitationsModal.svelte index 174d80c4fb..35a5594f17 100644 --- a/src/lib/components/chat/Messages/CitationsModal.svelte +++ b/src/lib/components/chat/Messages/CitationsModal.svelte @@ -139,13 +139,16 @@ {percentage.toFixed(2)}% {/if} + + {#if typeof document?.distance === 'number'} + + ({(document?.distance ?? 0).toFixed(4)}) + + {/if} + {:else if typeof document?.distance === 'number'} ({(document?.distance ?? 0).toFixed(4)}) - {:else} - - {(document?.distance ?? 0).toFixed(4)} - {/if}
diff --git a/src/lib/components/chat/Messages/ContentRenderer.svelte b/src/lib/components/chat/Messages/ContentRenderer.svelte index d3a5812f7d..74487255d5 100644 --- a/src/lib/components/chat/Messages/ContentRenderer.svelte +++ b/src/lib/components/chat/Messages/ContentRenderer.svelte @@ -154,11 +154,11 @@ }, [])} {onSourceClick} {onTaskClick} - on:update={(e) => { - dispatch('update', e.detail); + onUpdate={(value) => { + dispatch('update', value); }} - on:code={(e) => { - const { lang, code } = e.detail; + onCode={(value) => { + const { lang, code } = value; if ( ($settings?.detectArtifacts ?? true) && diff --git a/src/lib/components/chat/Messages/Markdown.svelte b/src/lib/components/chat/Messages/Markdown.svelte index 472e53e597..a014500ef3 100644 --- a/src/lib/components/chat/Messages/Markdown.svelte +++ b/src/lib/components/chat/Messages/Markdown.svelte @@ -7,9 +7,6 @@ import markedKatexExtension from '$lib/utils/marked/katex-extension'; import MarkdownTokens from './Markdown/MarkdownTokens.svelte'; - import { createEventDispatcher } from 'svelte'; - - const dispatch = createEventDispatcher(); export let id = ''; export let content; @@ -18,6 +15,9 @@ export let sourceIds = []; + export let onUpdate = () => {}; + export let onCode = () => {}; + export let onSourceClick = () => {}; export let onTaskClick = () => {}; @@ -40,17 +40,5 @@ {#key id} - { - dispatch('update', e.detail); - }} - on:code={(e) => { - dispatch('code', e.detail); - }} - /> + {/key} diff --git a/src/lib/components/chat/Messages/Markdown/HTMLToken.svelte b/src/lib/components/chat/Messages/Markdown/HTMLToken.svelte new file mode 100644 index 0000000000..3b81c225e9 --- /dev/null +++ b/src/lib/components/chat/Messages/Markdown/HTMLToken.svelte @@ -0,0 +1,49 @@ + + +{#if token.type === 'html'} + {#if html && html.includes(']*src="https:\/\/www\.youtube\.com\/embed\/([a-zA-Z0-9_-]{11})(?:\?[^"]*)?"[^>]*><\/iframe>/)} + {@const match = token.text.match( + /]*src="https:\/\/www\.youtube\.com\/embed\/([a-zA-Z0-9_-]{11})(?:\?[^"]*)?"[^>]*><\/iframe>/ + )} + {@const ytId = match && match[1]} + {#if ytId} + + {/if} + {:else if token.text.includes(`
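The new HTMLToken.svelte above recognizes sanitized YouTube embeds by matching the iframe's src and lifting out the 11-character video id. The same pattern, verified in Python for clarity (only the escaping differs from the Svelte source):

```python
import re

YT_IFRAME = re.compile(
    r'<iframe[^>]*src="https://www\.youtube\.com/embed/'
    r'([a-zA-Z0-9_-]{11})(?:\?[^"]*)?"[^>]*></iframe>'
)

html = '<iframe src="https://www.youtube.com/embed/dQw4w9WgXcQ?start=10"></iframe>'
match = YT_IFRAME.search(html)
assert match and match.group(1) == "dQw4w9WgXcQ"  # the embed's video id
```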