diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index 8a2e3438a6..5bf3973b33 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -19,7 +19,7 @@ jobs: uses: actions/checkout@v4 - uses: actions/setup-node@v4 with: - node-version: 18 + node-version: 22 - uses: actions/setup-python@v5 with: python-version: 3.11 diff --git a/CHANGELOG.md b/CHANGELOG.md index 86ff57384d..29715f6f34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,87 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.18] - 2025-02-27 + +### Fixed + +- **🌐 Open WebUI Now Works Over LAN in Insecure Context**: Resolved an issue preventing Open WebUI from functioning when accessed over a local network in an insecure context, ensuring seamless connectivity. +- **🔄 UI Now Reflects Deleted Connections Instantly**: Fixed an issue where deleting a connection did not update the UI in real time, ensuring accurate system state visibility. +- **🛠️ Models Now Display Correctly with ENABLE_FORWARD_USER_INFO_HEADERS**: Addressed a bug where models were not visible when ENABLE_FORWARD_USER_INFO_HEADERS was set, restoring proper model listing. + +## [0.5.17] - 2025-02-27 + +### Added + +- **🚀 Instant Document Upload with Bypass Embedding & Retrieval**: Admins can now enable "Bypass Embedding & Retrieval" in Admin Settings > Documents, significantly speeding up document uploads and ensuring full document context is retained without chunking. +- **🔎 "Stream" Hook for Real-Time Filtering**: The new "stream" hook allows dynamic real-time message filtering. Learn more in our documentation (https://docs.openwebui.com/features/plugin/functions/filter). 
+- **☁️ OneDrive Integration**: Early support for OneDrive storage integration has been introduced, expanding file import options. +- **📈 Enhanced Logging with Loguru**: Backend logging has been improved with Loguru, making debugging and issue tracking far more efficient. +- **⚙️ General Stability Enhancements**: Backend and frontend refactoring improves performance, ensuring a smoother and more reliable user experience. +- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages. + +### Fixed + +- **🔄 Reliable Model Imports from the Community Platform**: Resolved import failures, allowing seamless integration of community-shared models without errors. +- **📊 OpenAI Usage Statistics Restored**: Fixed an issue where OpenAI usage metrics were not displaying correctly, ensuring accurate tracking of usage data. +- **🗂️ Deduplication for Retrieved Documents**: Documents retrieved during searches are now intelligently deduplicated, meaning no more redundant results—helping to keep information concise and relevant. + +### Changed + +- **📝 "Full Context Mode" Renamed for Clarity**: The "Full Context Mode" toggle in Web Search settings is now labeled "Bypass Embedding & Retrieval" for consistency across the UI. + +## [0.5.16] - 2025-02-20 + +### Fixed + +- **🔍 Web Search Retrieval Restored**: Resolved a critical issue that broke web search retrieval by reverting deduplication changes, ensuring complete and accurate search results once again. + +## [0.5.15] - 2025-02-20 + +### Added + +- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding. 
+- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval. +- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data. +- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI. +- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions. +- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs. +- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations. +- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI. +- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience. + +### Fixed + +- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue. +- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow. 
+ +## [0.5.14] - 2025-02-17 + +### Fixed + +- **🔧 Critical Import Error Resolved**: Fixed a circular import issue preventing 'override_static' from being correctly imported in 'open_webui.config', ensuring smooth system initialization and stability. + +## [0.5.13] - 2025-02-17 + +### Added + +- **🌐 Full Context Mode for Web Search**: Enable highly accurate web searches by utilizing full context mode—ideal for models with large context windows, ensuring more precise and insightful results. +- **⚡ Optimized Asynchronous Web Search**: Web searches now load significantly faster with optimized async support, providing users with quicker, more efficient information retrieval. +- **🔄 Auto Text Direction for RTL Languages**: Automatic text alignment based on language input, ensuring seamless conversation flow for Arabic, Hebrew, and other right-to-left scripts. +- **🚀 Jupyter Notebook Support for Code Execution**: The "Run" button in code blocks can now use Jupyter for execution, offering a powerful, dynamic coding experience directly in the chat. +- **🗑️ Message Delete Confirmation Dialog**: Prevent accidental deletions with a new confirmation prompt before removing messages, adding an additional layer of security to your chat history. +- **📥 Download Button for SVG Diagrams**: SVG diagrams generated within chat can now be downloaded instantly, making it easier to save and share complex visual data. +- **✨ General UI/UX Improvements and Backend Stability**: A refined interface with smoother interactions, improved layouts, and backend stability enhancements for a more reliable, polished experience. + +### Fixed + +- **🛠️ Temporary Chat Message Continue Button Fixed**: The "Continue Response" button for temporary chats now works as expected, ensuring an uninterrupted conversation flow. + +### Changed + +- **📝 Prompt Variable Update**: Deprecated square bracket '[]' indicators for prompt variables; now requires double curly brackets '{{}}' for consistency and clarity. 
+- **🔧 Stability Enhancements**: Error handling improved in chat history, ensuring smoother operations when reviewing previous messages. + ## [0.5.12] - 2025-02-13 ### Added diff --git a/README.md b/README.md index 56ab09b05d..54ad41503d 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,15 @@ **Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**. -For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/). - ![Open WebUI Demo](./demo.gif) +> [!TIP] +> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)** +> +> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!** + +For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/). + ## Key Features of Open WebUI ⭐ - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images. 
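The config.py changes that follow introduce shared `CODE_EXECUTION_*` settings, with the older `CODE_INTERPRETER_*` variables falling back to them when unset (see the nested `os.environ.get` defaults in the hunks below). A minimal sketch of that nested-default pattern, with variable names taken from the diff and the helper function being illustrative, not part of the PR:

```python
import os

def env_with_fallback(primary: str, fallback: str, default: str = "") -> str:
    """Read `primary` from the environment; fall back to `fallback`, then `default`."""
    return os.environ.get(primary, os.environ.get(fallback, default))

# Simulate a deployment that only sets the new shared variable.
os.environ.pop("CODE_INTERPRETER_JUPYTER_URL", None)
os.environ["CODE_EXECUTION_JUPYTER_URL"] = "http://jupyter:8888"

# The legacy setting is unset, so the shared one takes effect.
url = env_with_fallback("CODE_INTERPRETER_JUPYTER_URL", "CODE_EXECUTION_JUPYTER_URL")
```

An explicitly set `CODE_INTERPRETER_JUPYTER_URL` still wins over the shared value, so existing deployments keep their current behavior.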
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index d62ac3114e..19abbf1436 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -9,7 +9,6 @@ from pathlib import Path from typing import Generic, Optional, TypeVar from urllib.parse import urlparse -import chromadb import requests from pydantic import BaseModel from sqlalchemy import JSON, Column, DateTime, Integer, func @@ -44,7 +43,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter()) # Function to run the alembic migrations def run_migrations(): - print("Running migrations") + log.info("Running migrations") try: from alembic import command from alembic.config import Config @@ -57,7 +56,7 @@ def run_migrations(): command.upgrade(alembic_cfg, "head") except Exception as e: - print(f"Error: {e}") + log.exception(f"Error running migrations: {e}") run_migrations() @@ -588,20 +587,6 @@ load_oauth_providers() STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve() - -def override_static(path: str, content: str): - # Ensure path is safe - if "/" in path or ".." 
in path: - log.error(f"Invalid path: {path}") - return - - file_path = os.path.join(STATIC_DIR, path) - os.makedirs(os.path.dirname(file_path), exist_ok=True) - - with open(file_path, "wb") as f: - f.write(base64.b64decode(content)) # Convert Base64 back to raw binary - - frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png" if frontend_favicon.exists(): @@ -692,12 +677,20 @@ S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None) S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None) S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None) S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None) +S3_USE_ACCELERATE_ENDPOINT = ( + os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true" +) +S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None) GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None) GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get( "GOOGLE_APPLICATION_CREDENTIALS_JSON", None ) +AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None) +AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None) +AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None) + #################################### # File Upload DIR #################################### @@ -797,6 +790,9 @@ ENABLE_OPENAI_API = PersistentConfig( OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") +GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "") + if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" @@ -1101,7 +1097,7 @@ try: banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]")) banners = [BannerModel(**banner) for banner in banners] except Exception as e: - print(f"Error loading WEBUI_BANNERS: {e}") + log.exception(f"Error loading WEBUI_BANNERS: {e}") banners = [] WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners) @@ 
-1377,6 +1373,44 @@ Responses from models: {{responses}}""" # Code Interpreter #################################### + +CODE_EXECUTION_ENGINE = PersistentConfig( + "CODE_EXECUTION_ENGINE", + "code_execution.engine", + os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"), +) + +CODE_EXECUTION_JUPYTER_URL = PersistentConfig( + "CODE_EXECUTION_JUPYTER_URL", + "code_execution.jupyter.url", + os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""), +) + +CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig( + "CODE_EXECUTION_JUPYTER_AUTH", + "code_execution.jupyter.auth", + os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""), +) + +CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig( + "CODE_EXECUTION_JUPYTER_AUTH_TOKEN", + "code_execution.jupyter.auth_token", + os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""), +) + + +CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig( + "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", + "code_execution.jupyter.auth_password", + os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""), +) + +CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig( + "CODE_EXECUTION_JUPYTER_TIMEOUT", + "code_execution.jupyter.timeout", + int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")), +) + ENABLE_CODE_INTERPRETER = PersistentConfig( "ENABLE_CODE_INTERPRETER", "code_interpreter.enable", @@ -1398,26 +1432,48 @@ CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig( CODE_INTERPRETER_JUPYTER_URL = PersistentConfig( "CODE_INTERPRETER_JUPYTER_URL", "code_interpreter.jupyter.url", - os.environ.get("CODE_INTERPRETER_JUPYTER_URL", ""), + os.environ.get( + "CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "") + ), ) CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig( "CODE_INTERPRETER_JUPYTER_AUTH", "code_interpreter.jupyter.auth", - os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH", ""), + os.environ.get( + "CODE_INTERPRETER_JUPYTER_AUTH", + os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""), + ), ) CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = 
PersistentConfig( "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", "code_interpreter.jupyter.auth_token", - os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", ""), + os.environ.get( + "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN", + os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""), + ), ) CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig( "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", "code_interpreter.jupyter.auth_password", - os.environ.get("CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", ""), + os.environ.get( + "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD", + os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""), + ), +) + +CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig( + "CODE_INTERPRETER_JUPYTER_TIMEOUT", + "code_interpreter.jupyter.timeout", + int( + os.environ.get( + "CODE_INTERPRETER_JUPYTER_TIMEOUT", + os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"), + ) + ), ) @@ -1445,21 +1501,27 @@ VECTOR_DB = os.environ.get("VECTOR_DB", "chroma") # Chroma CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" -CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) -CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) -CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "") -CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000")) -CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "") -CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "") -# Comma-separated list of header=value pairs -CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "") -if CHROMA_HTTP_HEADERS: - CHROMA_HTTP_HEADERS = dict( - [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")] + +if VECTOR_DB == "chroma": + import chromadb + + CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT) + CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE) + CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "") + CHROMA_HTTP_PORT = 
int(os.environ.get("CHROMA_HTTP_PORT", "8000")) + CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "") + CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get( + "CHROMA_CLIENT_AUTH_CREDENTIALS", "" ) -else: - CHROMA_HTTP_HEADERS = None -CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" + # Comma-separated list of header=value pairs + CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "") + if CHROMA_HTTP_HEADERS: + CHROMA_HTTP_HEADERS = dict( + [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")] + ) + else: + CHROMA_HTTP_HEADERS = None + CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # Milvus @@ -1513,6 +1575,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig( os.environ.get("GOOGLE_DRIVE_API_KEY", ""), ) +ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig( + "ENABLE_ONEDRIVE_INTEGRATION", + "onedrive.enable", + os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true", +) + +ONEDRIVE_CLIENT_ID = PersistentConfig( + "ONEDRIVE_CLIENT_ID", + "onedrive.client_id", + os.environ.get("ONEDRIVE_CLIENT_ID", ""), +) + # RAG Content Extraction CONTENT_EXTRACTION_ENGINE = PersistentConfig( "CONTENT_EXTRACTION_ENGINE", @@ -1526,6 +1600,26 @@ TIKA_SERVER_URL = PersistentConfig( os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment ) +DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig( + "DOCUMENT_INTELLIGENCE_ENDPOINT", + "rag.document_intelligence_endpoint", + os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""), +) + +DOCUMENT_INTELLIGENCE_KEY = PersistentConfig( + "DOCUMENT_INTELLIGENCE_KEY", + "rag.document_intelligence_key", + os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""), +) + + +BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig( + 
"BYPASS_EMBEDDING_AND_RETRIEVAL", + "rag.bypass_embedding_and_retrieval", + os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true", +) + + RAG_TOP_K = PersistentConfig( "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3")) ) @@ -1541,6 +1635,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true", ) +RAG_FULL_CONTEXT = PersistentConfig( + "RAG_FULL_CONTEXT", + "rag.full_context", + os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true", +) + RAG_FILE_MAX_COUNT = PersistentConfig( "RAG_FILE_MAX_COUNT", "rag.file.max_count", @@ -1655,7 +1755,7 @@ Respond to the user query using the provided context, incorporating inline citat - Respond in the same language as the user's query. - If the context is unreadable or of poor quality, inform the user and provide the best possible answer. - If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding. -- **Only include inline citations using [source_id] when a tag is explicitly provided in the context.** +- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `` tag is explicitly provided in the context.** - Do not cite if the tag is not provided in the context. - Do not use XML tags in your response. - Ensure citations are concise and directly related to the information provided. @@ -1736,6 +1836,12 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig( os.getenv("RAG_WEB_SEARCH_ENGINE", ""), ) +BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig( + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", + "rag.web.search.bypass_embedding_and_retrieval", + os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true", +) + # You can provide a list of your own websites to filter after performing a web search. # This ensures the highest level of safety and reliability of the information sources. 
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig( @@ -1883,10 +1989,34 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig( int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")), ) +RAG_WEB_LOADER_ENGINE = PersistentConfig( + "RAG_WEB_LOADER_ENGINE", + "rag.web.loader.engine", + os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"), +) + RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig( "RAG_WEB_SEARCH_TRUST_ENV", "rag.web.search.trust_env", - os.getenv("RAG_WEB_SEARCH_TRUST_ENV", False), + os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true", +) + +PLAYWRIGHT_WS_URI = PersistentConfig( + "PLAYWRIGHT_WS_URI", + "rag.web.loader.engine.playwright.ws.uri", + os.environ.get("PLAYWRIGHT_WS_URI", None), +) + +FIRECRAWL_API_KEY = PersistentConfig( + "FIRECRAWL_API_KEY", + "firecrawl.api_key", + os.environ.get("FIRECRAWL_API_KEY", ""), +) + +FIRECRAWL_API_BASE_URL = PersistentConfig( + "FIRECRAWL_API_BASE_URL", + "firecrawl.api_url", + os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"), ) #################################### @@ -2099,6 +2229,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig( os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY), ) +IMAGES_GEMINI_API_BASE_URL = PersistentConfig( + "IMAGES_GEMINI_API_BASE_URL", + "image_generation.gemini.api_base_url", + os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL), +) +IMAGES_GEMINI_API_KEY = PersistentConfig( + "IMAGES_GEMINI_API_KEY", + "image_generation.gemini.api_key", + os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY), +) + IMAGE_SIZE = PersistentConfig( "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") ) @@ -2275,7 +2416,7 @@ LDAP_SEARCH_BASE = PersistentConfig( LDAP_SEARCH_FILTERS = PersistentConfig( "LDAP_SEARCH_FILTER", "ldap.server.search_filter", - os.environ.get("LDAP_SEARCH_FILTER", ""), + os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")), ) LDAP_USE_TLS = PersistentConfig( diff --git 
a/backend/open_webui/env.py b/backend/open_webui/env.py index 96e288d777..ba546a2eb5 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -419,3 +419,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" if OFFLINE_MODE: os.environ["HF_HUB_OFFLINE"] = "1" + +#################################### +# AUDIT LOGGING +#################################### +ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true" +# Where to store log file +AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log" +# Maximum size of a file before rotating into a new log file +AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") +# METADATA | REQUEST | REQUEST_RESPONSE +AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper() +try: + MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048) +except ValueError: + MAX_BODY_LOG_SIZE = 2048 + +# Comma separated list for urls to exclude from audit +AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split( + "," +) +AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] +AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 274be56ec0..2f94f701e9 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -2,6 +2,7 @@ import logging import sys import inspect import json +import asyncio from pydantic import BaseModel from typing import AsyncGenerator, Generator, Iterator @@ -76,11 +77,13 @@ async def get_function_models(request): if hasattr(function_module, "pipes"): sub_pipes = [] - # Check if pipes is a function or a list - + # Handle pipes being a list, sync function, or async function try: if callable(function_module.pipes): - sub_pipes = function_module.pipes() + if asyncio.iscoroutinefunction(function_module.pipes): + sub_pipes = await 
function_module.pipes() + else: + sub_pipes = function_module.pipes() else: sub_pipes = function_module.pipes except Exception as e: diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a6fa6bd8c4..9676d144e9 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse +from open_webui.utils import logger +from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware +from open_webui.utils.logger import start_logger from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, @@ -95,12 +98,19 @@ from open_webui.config import ( OLLAMA_API_CONFIGS, # OpenAI ENABLE_OPENAI_API, + ONEDRIVE_CLIENT_ID, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, # Direct Connections ENABLE_DIRECT_CONNECTIONS, - # Code Interpreter + # Code Execution + CODE_EXECUTION_ENGINE, + CODE_EXECUTION_JUPYTER_URL, + CODE_EXECUTION_JUPYTER_AUTH, + CODE_EXECUTION_JUPYTER_AUTH_TOKEN, + CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + CODE_EXECUTION_JUPYTER_TIMEOUT, ENABLE_CODE_INTERPRETER, CODE_INTERPRETER_ENGINE, CODE_INTERPRETER_PROMPT_TEMPLATE, @@ -108,6 +118,7 @@ from open_webui.config import ( CODE_INTERPRETER_JUPYTER_AUTH, CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + CODE_INTERPRETER_JUPYTER_TIMEOUT, # Image AUTOMATIC1111_API_AUTH, AUTOMATIC1111_BASE_URL, @@ -126,6 +137,8 @@ from open_webui.config import ( IMAGE_STEPS, IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_KEY, + IMAGES_GEMINI_API_BASE_URL, + IMAGES_GEMINI_API_KEY, # Audio AUDIO_STT_ENGINE, AUDIO_STT_MODEL, @@ -140,6 +153,10 @@ from open_webui.config import ( AUDIO_TTS_VOICE, AUDIO_TTS_AZURE_SPEECH_REGION, AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, + PLAYWRIGHT_WS_URI, + FIRECRAWL_API_BASE_URL, + FIRECRAWL_API_KEY, + RAG_WEB_LOADER_ENGINE, WHISPER_MODEL, DEEPGRAM_API_KEY, WHISPER_MODEL_AUTO_UPDATE, @@ 
-147,6 +164,8 @@ from open_webui.config import ( # Retrieval RAG_TEMPLATE, DEFAULT_RAG_TEMPLATE, + RAG_FULL_CONTEXT, + BYPASS_EMBEDDING_AND_RETRIEVAL, RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, @@ -166,6 +185,8 @@ from open_webui.config import ( CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, TIKA_SERVER_URL, + DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY, RAG_TOP_K, RAG_TEXT_SPLITTER, TIKTOKEN_ENCODING_NAME, @@ -174,6 +195,7 @@ from open_webui.config import ( YOUTUBE_LOADER_PROXY_URL, # Retrieval (Web Search) RAG_WEB_SEARCH_ENGINE, + BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_WEB_SEARCH_TRUST_ENV, @@ -200,11 +222,13 @@ from open_webui.config import ( GOOGLE_PSE_ENGINE_ID, GOOGLE_DRIVE_CLIENT_ID, GOOGLE_DRIVE_API_KEY, + ONEDRIVE_CLIENT_ID, ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ENABLE_RAG_WEB_SEARCH, ENABLE_GOOGLE_DRIVE_INTEGRATION, + ENABLE_ONEDRIVE_INTEGRATION, UPLOAD_DIR, # WebUI WEBUI_AUTH, @@ -283,8 +307,11 @@ from open_webui.config import ( reset_config, ) from open_webui.env import ( + AUDIT_EXCLUDED_PATHS, + AUDIT_LOG_LEVEL, CHANGELOG, GLOBAL_LOG_LEVEL, + MAX_BODY_LOG_SIZE, SAFE_MODE, SRC_LOG_LEVELS, VERSION, @@ -369,6 +396,7 @@ https://github.com/open-webui/open-webui @asynccontextmanager async def lifespan(app: FastAPI): + start_logger() if RESET_CONFIG_ON_START: reset_config() @@ -509,6 +537,9 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT +app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION @@ -516,6 +547,8 
@@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL +app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT +app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME @@ -543,9 +576,13 @@ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE +app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( + BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL +) app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION +app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID @@ -569,7 +606,11 @@ app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS +app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV +app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI +app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL +app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY app.state.EMBEDDING_FUNCTION = None app.state.ef = None @@ -613,10 +654,19 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function( ######################################## # -# CODE INTERPRETER +# CODE EXECUTION # ######################################## 
+app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE +app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL +app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH +app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN +app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( + CODE_EXECUTION_JUPYTER_AUTH_PASSWORD +) +app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT + app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE @@ -629,6 +679,7 @@ app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = ( app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD ) +app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT ######################################## # @@ -643,6 +694,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY +app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL +app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY + app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL @@ -844,6 +898,19 @@ app.include_router( app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) +try: + audit_level = AuditLevel(AUDIT_LOG_LEVEL) +except ValueError as e: + logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. 
Error: {e}") + audit_level = AuditLevel.NONE + +if audit_level != AuditLevel.NONE: + app.add_middleware( + AuditLoggingMiddleware, + audit_level=audit_level, + excluded_paths=AUDIT_EXCLUDED_PATHS, + max_body_size=MAX_BODY_LOG_SIZE, + ) ################################## # # Chat Endpoints @@ -876,7 +943,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)): return filtered_models - models = await get_all_models(request) + models = await get_all_models(request, user=user) # Filter out filter pipelines models = [ @@ -905,7 +972,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)): @app.get("/api/models/base") async def get_base_models(request: Request, user=Depends(get_admin_user)): - models = await get_all_base_models(request) + models = await get_all_base_models(request, user=user) return {"data": models} @@ -916,7 +983,7 @@ async def chat_completion( user=Depends(get_verified_user), ): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) model_item = form_data.pop("model_item", {}) tasks = form_data.pop("background_tasks", None) @@ -952,7 +1019,7 @@ async def chat_completion( "files": form_data.get("files", None), "features": form_data.get("features", None), "variables": form_data.get("variables", None), - "model": model_info, + "model": model_info.model_dump() if model_info else model, "direct": model_item.get("direct", False), **( {"function_calling": "native"} @@ -1111,6 +1178,7 @@ async def get_app_config(request: Request): "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION, } if user is not None else {} @@ -1120,6 +1188,9 @@ async def get_app_config(request: Request): { "default_models": app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": 
app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "code": { + "engine": app.state.config.CODE_EXECUTION_ENGINE, + }, "audio": { "tts": { "engine": app.state.config.TTS_ENGINE, @@ -1139,6 +1210,7 @@ async def get_app_config(request: Request): "client_id": GOOGLE_DRIVE_CLIENT_ID.value, "api_key": GOOGLE_DRIVE_API_KEY.value, }, + "onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value}, } if user is not None else {} diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 9e0a5865e9..a222d221c0 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -1,3 +1,4 @@ +import logging import json import time import uuid @@ -5,7 +6,7 @@ from typing import Optional from open_webui.internal.db import Base, get_db from open_webui.models.tags import TagModel, Tag, Tags - +from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON @@ -16,6 +17,9 @@ from sqlalchemy.sql import exists # Chat DB Schema #################### +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + class Chat(Base): __tablename__ = "chat" @@ -670,7 +674,7 @@ class ChatTable: # Perform pagination at the SQL level all_chats = query.offset(skip).limit(limit).all() - print(len(all_chats)) + log.info(f"The number of chats: {len(all_chats)}") # Validate and return chats return [ChatModel.model_validate(chat) for chat in all_chats] @@ -731,7 +735,7 @@ class ChatTable: query = db.query(Chat).filter_by(user_id=user_id) tag_id = tag_name.replace(" ", "_").lower() - print(db.bind.dialect.name) + log.info(f"DB dialect name: {db.bind.dialect.name}") if db.bind.dialect.name == "sqlite": # SQLite JSON1 querying for tags within the meta JSON field query = query.filter( @@ -752,7 +756,7 @@ class ChatTable: ) all_chats = query.all() - print("all_chats", all_chats) + log.debug(f"all_chats: {all_chats}") return [ChatModel.model_validate(chat) for 
chat in all_chats] def add_chat_tag_by_id_and_user_id_and_tag_name( @@ -810,7 +814,7 @@ class ChatTable: count = query.count() # Debugging output for inspection - print(f"Count of chats for tag '{tag_name}':", count) + log.info(f"Count of chats for tag '{tag_name}': {count}") return count diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 7ff5c45408..215e36aa24 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -118,7 +118,7 @@ class FeedbackTable: else: return None except Exception as e: - print(e) + log.exception(f"Error creating a new feedback: {e}") return None def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 91dea54443..6f1511cd13 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -119,7 +119,7 @@ class FilesTable: else: return None except Exception as e: - print(f"Error creating tool: {e}") + log.exception(f"Error inserting a new file: {e}") return None def get_file_by_id(self, id: str) -> Optional[FileModel]: diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 040774196b..19739bc5f5 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -82,7 +82,7 @@ class FolderTable: else: return None except Exception as e: - print(e) + log.exception(f"Error inserting a new folder: {e}") return None def get_folder_by_id_and_user_id( diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 6c6aed8623..8cbfc5de7d 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -105,7 +105,7 @@ class FunctionsTable: else: return None except Exception as e: - print(f"Error creating tool: {e}") + log.exception(f"Error creating a new function: {e}") return None def 
get_function_by_id(self, id: str) -> Optional[FunctionModel]: @@ -170,7 +170,7 @@ class FunctionsTable: function = db.get(Function, id) return function.valves if function.valves else {} except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error getting function valves by id {id}: {e}") return None def update_function_valves_by_id( @@ -202,7 +202,9 @@ class FunctionsTable: return user_settings["functions"]["valves"].get(id, {}) except Exception as e: - print(f"An error occurred: {e}") + log.exception( + f"Error getting user values by id {id} and user id {user_id}: {e}" + ) return None def update_user_valves_by_id_and_user_id( @@ -225,7 +227,9 @@ class FunctionsTable: return user_settings["functions"]["valves"][id] except Exception as e: - print(f"An error occurred: {e}") + log.exception( + f"Error updating user valves by id {id} and user_id {user_id}: {e}" + ) return None def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py old mode 100644 new mode 100755 index f2f59d7c49..7df8d8656b --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -166,7 +166,7 @@ class ModelsTable: else: return None except Exception as e: - print(e) + log.exception(f"Failed to insert a new model: {e}") return None def get_all_models(self) -> list[ModelModel]: @@ -246,8 +246,7 @@ class ModelsTable: db.refresh(model) return ModelModel.model_validate(model) except Exception as e: - print(e) - + log.exception(f"Failed to update the model by id {id}: {e}") return None def delete_model_by_id(self, id: str) -> bool: diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 3e812db95d..279dc624d5 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -61,7 +61,7 @@ class TagTable: else: return None except Exception as e: - print(e) + log.exception(f"Error inserting 
a new tag: {e}") return None def get_tag_by_name_and_user_id( diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index a5f13ebb71..68a83ea42c 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -131,7 +131,7 @@ class ToolsTable: else: return None except Exception as e: - print(f"Error creating tool: {e}") + log.exception(f"Error creating a new tool: {e}") return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: @@ -175,7 +175,7 @@ class ToolsTable: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Error getting tool valves by id {id}: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: @@ -204,7 +204,9 @@ class ToolsTable: return user_settings["tools"]["valves"].get(id, {}) except Exception as e: - print(f"An error occurred: {e}") + log.exception( + f"Error getting user values by id {id} and user_id {user_id}: {e}" + ) return None def update_user_valves_by_id_and_user_id( @@ -227,7 +229,9 @@ class ToolsTable: return user_settings["tools"]["valves"][id] except Exception as e: - print(f"An error occurred: {e}") + log.exception( + f"Error updating user valves by id {id} and user_id {user_id}: {e}" + ) return None def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index a9372f65a6..7fa24ced37 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -4,6 +4,7 @@ import ftfy import sys from langchain_community.document_loaders import ( + AzureAIDocumentIntelligenceLoader, BSHTMLLoader, CSVLoader, Docx2txtLoader, @@ -76,6 +77,7 @@ known_source_ext = [ "jsx", "hs", "lhs", + "json", ] @@ -147,6 +149,27 @@ class Loader: file_path=file_path, mime_type=file_content_type, 
                )
+        elif (
+            self.engine == "document_intelligence"
+            and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
+            and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
+            and (
+                file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
+                or file_content_type
+                in [
+                    "application/vnd.ms-excel",
+                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+                    "application/vnd.ms-powerpoint",
+                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+                ]
+            )
+        ):
+            loader = AzureAIDocumentIntelligenceLoader(
+                file_path=file_path,
+                api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
+                api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
+            )
        else:
            if file_ext == "pdf":
                loader = PyPDFLoader(
diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py
index ea3204cb8b..5b7499fd18 100644
--- a/backend/open_webui/retrieval/models/colbert.py
+++ b/backend/open_webui/retrieval/models/colbert.py
@@ -1,13 +1,19 @@
 import os
+import logging
 import torch
 import numpy as np
 from colbert.infra import ColBERTConfig
 from colbert.modeling.checkpoint import Checkpoint
 
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
 
 class ColBERT:
     def __init__(self, name, **kwargs) -> None:
-        print("ColBERT: Loading model", name)
+        log.info(f"ColBERT: Loading model {name}")
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
         DOCKER = kwargs.get("env") == "docker"
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index b7da20ad5f..b6253e63cc 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -5,6 +5,7 @@ from typing import Optional, Union
 
 import asyncio
 import requests
+import hashlib
 
 from huggingface_hub import snapshot_download
 from langchain.retrievers import
ContextualCompressionRetriever, EnsembleRetriever
@@ -14,8 +15,10 @@
 from langchain_core.documents import Document
 
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.misc import get_last_user_message
+from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
+
 from open_webui.models.users import UserModel
+from open_webui.models.files import Files
 
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -80,7 +83,20 @@ def query_doc(
         return result
     except Exception as e:
-        print(e)
+        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
+        raise e
+
+
+def get_doc(collection_name: str, user: UserModel = None):
+    try:
+        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+
+        if result:
+            log.info(f"get_doc:result {result.ids} {result.metadatas}")
+
+        return result
+    except Exception as e:
+        log.exception(f"Error getting doc {collection_name}: {e}")
         raise e
@@ -137,47 +153,80 @@ def query_doc_with_hybrid_search(
         raise e
 
 
-def merge_and_sort_query_results(
-    query_results: list[dict], k: int, reverse: bool = False
-) -> list[dict]:
+def merge_get_results(get_results: list[dict]) -> dict:
     # Initialize lists to store combined data
-    combined_distances = []
     combined_documents = []
     combined_metadatas = []
+    combined_ids = []
 
-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
+    for data in get_results:
         combined_documents.extend(data["documents"][0])
         combined_metadatas.extend(data["metadatas"][0])
+        combined_ids.extend(data["ids"][0])
 
-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
+    # Create the output dictionary
+    result = {
+        "documents": [combined_documents],
+        "metadatas": [combined_metadatas],
+        "ids": [combined_ids],
+    }
+
+    return result
+
+
+def merge_and_sort_query_results(
+    query_results: list[dict], k: int, reverse: bool = False
+) -> dict:
+    # Initialize lists to store combined data
+    combined = []
+    seen_hashes = set()  # To store unique document hashes
+
+    for data in query_results:
+        distances = data["distances"][0]
+        documents = data["documents"][0]
+        metadatas = data["metadatas"][0]
+
+        for distance, document, metadata in zip(distances, documents, metadatas):
+            if isinstance(document, str):
+                doc_hash = hashlib.md5(
+                    document.encode()
+                ).hexdigest()  # Compute a hash for uniqueness
+
+                if doc_hash not in seen_hashes:
+                    seen_hashes.add(doc_hash)
+                    combined.append((distance, document, metadata))
 
     # Sort the list based on distances
     combined.sort(key=lambda x: x[0], reverse=reverse)
 
-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
-    else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
+    # Slice to keep only the top k elements
+    sorted_distances, sorted_documents, sorted_metadatas = (
+        zip(*combined[:k]) if combined else ([], [], [])
+    )
 
-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_documents = list(sorted_documents)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-
-    # Create the output dictionary
-    result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
+    # Create and return the output dictionary
+    return {
+        "distances": [list(sorted_distances)],
+        "documents": [list(sorted_documents)],
+        "metadatas": [list(sorted_metadatas)],
     }
 
-    return result
+
+def get_all_items_from_collections(collection_names: list[str]) -> dict:
+    results = []
+
+    for collection_name in collection_names:
+        if collection_name:
+            try:
+                result = get_doc(collection_name=collection_name)
+                if result is not None:
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
+        else:
+            pass
+
+    return merge_get_results(results)
 
 
 def query_collection(
@@ -290,6 +339,7 @@ def get_embedding_function(
 
 
 def get_sources_from_files(
+    request,
     files,
     queries,
     embedding_function,
@@ -297,21 +347,74 @@
     reranking_function,
     r,
     hybrid_search,
+    full_context=False,
 ):
-    log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
+    log.debug(
+        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+    )
 
     extracted_collections = []
     relevant_contexts = []
 
     for file in files:
-        if file.get("context") == "full":
+
+        context = None
+        if file.get("docs"):
+            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
+            context = {
+                "documents": [[doc.get("content") for doc in file.get("docs")]],
+                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
+            }
+        elif file.get("context") == "full":
+            # Manual Full Mode Toggle
             context = {
                 "documents": [[file.get("file").get("data", {}).get("content")]],
                 "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
             }
-        else:
-            context = None
+        elif (
+            file.get("type") != "web_search"
+            and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+        ):
+            # BYPASS_EMBEDDING_AND_RETRIEVAL
+            if file.get("type") == "collection":
+                file_ids = file.get("data", {}).get("file_ids", [])
+
+                documents = []
+                metadatas = []
+                for file_id in file_ids:
+                    file_object = Files.get_file_by_id(file_id)
+
+                    if file_object:
+                        documents.append(file_object.data.get("content", ""))
+                        metadatas.append(
+                            {
+                                "file_id": file_id,
+                                "name": file_object.filename,
+                                "source": file_object.filename,
+                            }
+                        )
+
+                context = {
+                    "documents": [documents],
+                    "metadatas": [metadatas],
+                }
+
+            elif file.get("id"):
+                file_object = Files.get_file_by_id(file.get("id"))
+                if file_object:
+                    context = {
+                        "documents": [[file_object.data.get("content", "")]],
+                        "metadatas": [
+                            [
+                                {
+                                    "file_id": file.get("id"),
+                                    "name": file_object.filename,
+                                    "source": file_object.filename,
+                                }
+                            ]
+                        ],
+                    }
+        else:
            collection_names = []
            if file.get("type") == "collection":
                if file.get("legacy"):
@@ -331,42 +434,50 @@
                log.debug(f"skipping {file} as it has already been extracted")
                continue
 
-            try:
-                context = None
-                if file.get("type") == "text":
-                    context = file["content"]
-                else:
-                    if hybrid_search:
-                        try:
-                            context = query_collection_with_hybrid_search(
+            if full_context:
+                try:
+                    context = get_all_items_from_collections(collection_names)
+                except Exception as e:
+                    log.exception(e)
+
+            else:
+                try:
+                    context = None
+                    if file.get("type") == "text":
+                        context = file["content"]
+                    else:
+                        if hybrid_search:
+                            try:
+                                context = query_collection_with_hybrid_search(
+                                    collection_names=collection_names,
+                                    queries=queries,
+                                    embedding_function=embedding_function,
+                                    k=k,
+                                    reranking_function=reranking_function,
+                                    r=r,
+                                )
+                            except Exception as e:
+                                log.debug(
+                                    "Error when using hybrid search, using"
+                                    " non hybrid search as fallback."
+                                )
+
+                        if (not hybrid_search) or (context is None):
+                            context = query_collection(
                                 collection_names=collection_names,
                                 queries=queries,
                                 embedding_function=embedding_function,
                                 k=k,
-                                reranking_function=reranking_function,
-                                r=r,
                             )
-                        except Exception as e:
-                            log.debug(
-                                "Error when using hybrid search, using"
-                                " non hybrid search as fallback."
-                            )
-
-                    if (not hybrid_search) or (context is None):
-                        context = query_collection(
-                            collection_names=collection_names,
-                            queries=queries,
-                            embedding_function=embedding_function,
-                            k=k,
-                        )
-            except Exception as e:
-                log.exception(e)
+                except Exception as e:
+                    log.exception(e)
 
         extracted_collections.extend(collection_names)
 
     if context:
         if "data" in file:
             del file["data"]
+
         relevant_contexts.append({**context, "file": file})
 
     sources = []
@@ -463,7 +574,7 @@ def generate_openai_batch_embeddings(
         else:
             raise "Something went wrong :/"
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating openai batch embeddings: {e}")
        return None
@@ -497,7 +608,7 @@ def generate_ollama_batch_embeddings(
         else:
             raise "Something went wrong :/"
     except Exception as e:
-        print(e)
+        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None
diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py
old mode 100644
new mode 100755
index c40618fcc5..006ee20763
--- a/backend/open_webui/retrieval/vector/dbs/chroma.py
+++ b/backend/open_webui/retrieval/vector/dbs/chroma.py
@@ -1,4 +1,5 @@
 import chromadb
+import logging
 from chromadb import Settings
 from chromadb.utils.batch_utils import create_batches
 
@@ -16,6 +17,10 @@ from open_webui.config import (
     CHROMA_CLIENT_AUTH_PROVIDER,
     CHROMA_CLIENT_AUTH_CREDENTIALS,
 )
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 class ChromaClient:
@@ -102,8 +107,7 @@
                 }
             )
             return None
-        except Exception as e:
-            print(e)
+        except Exception:
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py
index 43c3f3d1a1..3fd5b1ccd3 100644
--- a/backend/open_webui/retrieval/vector/dbs/milvus.py
+++ b/backend/open_webui/retrieval/vector/dbs/milvus.py
@@ -1,7 +1,7 @@
 from pymilvus
import MilvusClient as Client from pymilvus import FieldSchema, DataType import json - +import logging from typing import Optional from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult @@ -10,6 +10,10 @@ from open_webui.config import ( MILVUS_DB, MILVUS_TOKEN, ) +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) class MilvusClient: @@ -168,7 +172,7 @@ class MilvusClient: try: # Loop until there are no more items to fetch or the desired limit is reached while remaining > 0: - print("remaining", remaining) + log.info(f"remaining: {remaining}") current_fetch = min( max_limit, remaining ) # Determine how many items to fetch in this iteration @@ -195,10 +199,12 @@ class MilvusClient: if results_count < current_fetch: break - print(all_results) + log.debug(all_results) return self._result_to_get_result([all_results]) except Exception as e: - print(e) + log.exception( + f"Error querying collection {collection_name} with limit {limit}: {e}" + ) return None def get(self, collection_name: str) -> Optional[GetResult]: diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 341b3056fa..eab02232f4 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -1,4 +1,5 @@ from typing import Optional, List, Dict, Any +import logging from sqlalchemy import ( cast, column, @@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH +from open_webui.env import SRC_LOG_LEVELS + VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH Base = declarative_base() +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + class DocumentChunk(Base): __tablename__ = "document_chunk" @@ 
-82,10 +88,10 @@ class PgvectorClient: ) ) self.session.commit() - print("Initialization complete.") + log.info("Initialization complete.") except Exception as e: self.session.rollback() - print(f"Error during initialization: {e}") + log.exception(f"Error during initialization: {e}") raise def check_vector_length(self) -> None: @@ -150,12 +156,12 @@ class PgvectorClient: new_items.append(new_chunk) self.session.bulk_save_objects(new_items) self.session.commit() - print( + log.info( f"Inserted {len(new_items)} items into collection '{collection_name}'." ) except Exception as e: self.session.rollback() - print(f"Error during insert: {e}") + log.exception(f"Error during insert: {e}") raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -184,10 +190,12 @@ class PgvectorClient: ) self.session.add(new_chunk) self.session.commit() - print(f"Upserted {len(items)} items into collection '{collection_name}'.") + log.info( + f"Upserted {len(items)} items into collection '{collection_name}'." 
+ ) except Exception as e: self.session.rollback() - print(f"Error during upsert: {e}") + log.exception(f"Error during upsert: {e}") raise def search( @@ -278,7 +286,7 @@ class PgvectorClient: ids=ids, distances=distances, documents=documents, metadatas=metadatas ) except Exception as e: - print(f"Error during search: {e}") + log.exception(f"Error during search: {e}") return None def query( @@ -310,7 +318,7 @@ class PgvectorClient: metadatas=metadatas, ) except Exception as e: - print(f"Error during query: {e}") + log.exception(f"Error during query: {e}") return None def get( @@ -334,7 +342,7 @@ class PgvectorClient: return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: - print(f"Error during get: {e}") + log.exception(f"Error during get: {e}") return None def delete( @@ -356,22 +364,22 @@ class PgvectorClient: ) deleted = query.delete(synchronize_session=False) self.session.commit() - print(f"Deleted {deleted} items from collection '{collection_name}'.") + log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: self.session.rollback() - print(f"Error during delete: {e}") + log.exception(f"Error during delete: {e}") raise def reset(self) -> None: try: deleted = self.session.query(DocumentChunk).delete() self.session.commit() - print( + log.info( f"Reset complete. Deleted {deleted} items from 'document_chunk' table." 
) except Exception as e: self.session.rollback() - print(f"Error during reset: {e}") + log.exception(f"Error during reset: {e}") raise def close(self) -> None: @@ -387,9 +395,9 @@ class PgvectorClient: ) return exists except Exception as e: - print(f"Error checking collection existence: {e}") + log.exception(f"Error checking collection existence: {e}") return False def delete_collection(self, collection_name: str) -> None: self.delete(collection_name) - print(f"Collection '{collection_name}' deleted.") + log.info(f"Collection '{collection_name}' deleted.") diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index f077ae45ac..28f0b37793 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -1,4 +1,5 @@ from typing import Optional +import logging from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct @@ -6,9 +7,13 @@ from qdrant_client.models import models from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import QDRANT_URI, QDRANT_API_KEY +from open_webui.env import SRC_LOG_LEVELS NO_LIMIT = 999999999 +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + class QdrantClient: def __init__(self): @@ -49,7 +54,7 @@ class QdrantClient: ), ) - print(f"collection {collection_name_with_prefix} successfully created!") + log.info(f"collection {collection_name_with_prefix} successfully created!") def _create_collection_if_not_exists(self, collection_name, dimension): if not self.has_collection(collection_name=collection_name): @@ -120,7 +125,7 @@ class QdrantClient: ) return self._result_to_get_result(points.points) except Exception as e: - print(e) + log.exception(f"Error querying a collection '{collection_name}': {e}") return None def get(self, collection_name: str) -> Optional[GetResult]: diff --git 
a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index 7839b715ee..da70aa8e7f 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -27,8 +27,7 @@ def search_tavily( """ url = "https://api.tavily.com/search" data = {"query": query, "api_key": api_key} - include_domain = filter_list - response = requests.post(url, include_domain, json=data) + response = requests.post(url, json=data) response.raise_for_status() json_response = response.json() diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 145f1adbca..fd94a1a32f 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -1,22 +1,38 @@ -import socket -import aiohttp import asyncio -import urllib.parse -import validators -from typing import Any, AsyncIterator, Dict, Iterator, List, Sequence, Union - - -from langchain_community.document_loaders import ( - WebBaseLoader, -) -from langchain_core.documents import Document - - -from open_webui.constants import ERROR_MESSAGES -from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH -from open_webui.env import SRC_LOG_LEVELS - import logging +import socket +import ssl +import urllib.parse +import urllib.request +from collections import defaultdict +from datetime import datetime, time, timedelta +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + Union, + Literal, +) +import aiohttp +import certifi +import validators +from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader +from langchain_community.document_loaders.firecrawl import FireCrawlLoader +from langchain_community.document_loaders.base import BaseLoader +from langchain_core.documents import Document +from open_webui.constants import ERROR_MESSAGES +from open_webui.config import ( + ENABLE_RAG_LOCAL_WEB_FETCH, + PLAYWRIGHT_WS_URI, + 
RAG_WEB_LOADER_ENGINE, + FIRECRAWL_API_BASE_URL, + FIRECRAWL_API_KEY, +) +from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -68,6 +84,314 @@ def resolve_hostname(hostname): return ipv4_addresses, ipv6_addresses +def extract_metadata(soup, url): + metadata = {"source": url} + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get("content", "No description found.") + if html := soup.find("html"): + metadata["language"] = html.get("lang", "No language found.") + return metadata + + +def verify_ssl_cert(url: str) -> bool: + """Verify SSL certificate for the given URL.""" + if not url.startswith("https://"): + return True + + try: + hostname = url.split("://")[-1].split("/")[0] + context = ssl.create_default_context(cafile=certifi.where()) + with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s: + s.connect((hostname, 443)) + return True + except ssl.SSLError: + return False + except Exception as e: + log.warning(f"SSL verification failed for {url}: {str(e)}") + return False + + +class SafeFireCrawlLoader(BaseLoader): + def __init__( + self, + web_paths, + verify_ssl: bool = True, + trust_env: bool = False, + requests_per_second: Optional[float] = None, + continue_on_failure: bool = True, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + mode: Literal["crawl", "scrape", "map"] = "crawl", + proxy: Optional[Dict[str, str]] = None, + params: Optional[Dict] = None, + ): + """Concurrent document loader for FireCrawl operations. + + Executes multiple FireCrawlLoader instances concurrently using thread pooling + to improve bulk processing efficiency. + Args: + web_paths: List of URLs/paths to process. + verify_ssl: If True, verify SSL certificates. + trust_env: If True, use proxy settings from environment variables. 
+        requests_per_second: Number of requests per second to limit to.
+        continue_on_failure (bool): If True, continue loading other URLs on failure.
+        api_key: API key for FireCrawl service. Defaults to None
+            (uses FIRE_CRAWL_API_KEY environment variable if not provided).
+        api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
+        mode: Operation mode selection:
+            - 'crawl': Website crawling mode (default)
+            - 'scrape': Direct page scraping
+            - 'map': Site map generation
+        proxy: Proxy override settings for the FireCrawl API.
+        params: The parameters to pass to the Firecrawl API.
+            Examples include crawlerOptions.
+            For more details, visit: https://github.com/mendableai/firecrawl-py
+        """
+        proxy_server = proxy.get("server") if proxy else None
+        if trust_env and not proxy_server:
+            env_proxies = urllib.request.getproxies()
+            env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
+            if env_proxy_server:
+                if proxy:
+                    proxy["server"] = env_proxy_server
+                else:
+                    proxy = {"server": env_proxy_server}
+        self.web_paths = web_paths
+        self.verify_ssl = verify_ssl
+        self.requests_per_second = requests_per_second
+        self.last_request_time = None
+        self.trust_env = trust_env
+        self.continue_on_failure = continue_on_failure
+        self.api_key = api_key
+        self.api_url = api_url
+        self.mode = mode
+        self.params = params
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Load documents concurrently using FireCrawl."""
+        for url in self.web_paths:
+            try:
+                self._safe_process_url_sync(url)
+                loader = FireCrawlLoader(
+                    url=url,
+                    api_key=self.api_key,
+                    api_url=self.api_url,
+                    mode=self.mode,
+                    params=self.params,
+                )
+                yield from loader.lazy_load()
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.exception("Error loading %s", url)
+                    continue
+                raise e
+
+    async def alazy_load(self):
+        """Async version of lazy_load."""
+        for url in self.web_paths:
+            try:
+                await self._safe_process_url(url)
+                loader = FireCrawlLoader(
+                    url=url,
+                    api_key=self.api_key,
+                    api_url=self.api_url,
+                    mode=self.mode,
+                    params=self.params,
+                )
+                async for document in loader.alazy_load():
+                    yield document
+            except Exception as e:
+                if self.continue_on_failure:
+                    log.exception("Error loading %s", url)
+                    continue
+                raise e
+
+    def _verify_ssl_cert(self, url: str) -> bool:
+        return verify_ssl_cert(url)
+
+    async def _wait_for_rate_limit(self):
+        """Wait to respect the rate limit if specified."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                await asyncio.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    def _sync_wait_for_rate_limit(self):
+        """Synchronous version of rate limit wait."""
+        if self.requests_per_second and self.last_request_time:
+            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
+            time_since_last = datetime.now() - self.last_request_time
+            if time_since_last < min_interval:
+                # `time` imported above is datetime.time (no sleep); use the time module
+                import time as time_module
+
+                time_module.sleep((min_interval - time_since_last).total_seconds())
+        self.last_request_time = datetime.now()
+
+    async def _safe_process_url(self, url: str) -> bool:
+        """Perform safety checks before processing a URL."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        await self._wait_for_rate_limit()
+        return True
+
+    def _safe_process_url_sync(self, url: str) -> bool:
+        """Synchronous version of safety checks."""
+        if self.verify_ssl and not self._verify_ssl_cert(url):
+            raise ValueError(f"SSL certificate verification failed for {url}")
+        self._sync_wait_for_rate_limit()
+        return True
+
+
+class SafePlaywrightURLLoader(PlaywrightURLLoader):
+    """Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
+ + Attributes: + web_paths (List[str]): List of URLs to load. + verify_ssl (bool): If True, verify SSL certificates. + trust_env (bool): If True, use proxy settings from environment variables. + requests_per_second (Optional[float]): Number of requests per second to limit to. + continue_on_failure (bool): If True, continue loading other URLs on failure. + headless (bool): If True, the browser will run in headless mode. + proxy (dict): Proxy override settings for the Playwright session. + playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection. + """ + + def __init__( + self, + web_paths: List[str], + verify_ssl: bool = True, + trust_env: bool = False, + requests_per_second: Optional[float] = None, + continue_on_failure: bool = True, + headless: bool = True, + remove_selectors: Optional[List[str]] = None, + proxy: Optional[Dict[str, str]] = None, + playwright_ws_url: Optional[str] = None, + ): + """Initialize with additional safety parameters and remote browser support.""" + + proxy_server = proxy.get("server") if proxy else None + if trust_env and not proxy_server: + env_proxies = urllib.request.getproxies() + env_proxy_server = env_proxies.get("https") or env_proxies.get("http") + if env_proxy_server: + if proxy: + proxy["server"] = env_proxy_server + else: + proxy = {"server": env_proxy_server} + + # We'll set headless to False if using playwright_ws_url since it's handled by the remote browser + super().__init__( + urls=web_paths, + continue_on_failure=continue_on_failure, + headless=headless if playwright_ws_url is None else False, + remove_selectors=remove_selectors, + proxy=proxy, + ) + self.verify_ssl = verify_ssl + self.requests_per_second = requests_per_second + self.last_request_time = None + self.playwright_ws_url = playwright_ws_url + self.trust_env = trust_env + + def lazy_load(self) -> Iterator[Document]: + """Safely load URLs synchronously with support for remote browser.""" + from playwright.sync_api import 
sync_playwright + + with sync_playwright() as p: + # Use remote browser if ws_endpoint is provided, otherwise use local browser + if self.playwright_ws_url: + browser = p.chromium.connect(self.playwright_ws_url) + else: + browser = p.chromium.launch(headless=self.headless, proxy=self.proxy) + + for url in self.urls: + try: + self._safe_process_url_sync(url) + page = browser.new_page() + response = page.goto(url) + if response is None: + raise ValueError(f"page.goto() returned None for url {url}") + + text = self.evaluator.evaluate(page, browser, response) + metadata = {"source": url} + yield Document(page_content=text, metadata=metadata) + except Exception as e: + if self.continue_on_failure: + log.exception(e, "Error loading %s", url) + continue + raise e + browser.close() + + async def alazy_load(self) -> AsyncIterator[Document]: + """Safely load URLs asynchronously with support for remote browser.""" + from playwright.async_api import async_playwright + + async with async_playwright() as p: + # Use remote browser if ws_endpoint is provided, otherwise use local browser + if self.playwright_ws_url: + browser = await p.chromium.connect(self.playwright_ws_url) + else: + browser = await p.chromium.launch( + headless=self.headless, proxy=self.proxy + ) + + for url in self.urls: + try: + await self._safe_process_url(url) + page = await browser.new_page() + response = await page.goto(url) + if response is None: + raise ValueError(f"page.goto() returned None for url {url}") + + text = await self.evaluator.evaluate_async(page, browser, response) + metadata = {"source": url} + yield Document(page_content=text, metadata=metadata) + except Exception as e: + if self.continue_on_failure: + log.exception(e, "Error loading %s", url) + continue + raise e + await browser.close() + + def _verify_ssl_cert(self, url: str) -> bool: + return verify_ssl_cert(url) + + async def _wait_for_rate_limit(self): + """Wait to respect the rate limit if specified.""" + if self.requests_per_second 
and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + await asyncio.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + def _sync_wait_for_rate_limit(self): + """Synchronous version of rate limit wait.""" + if self.requests_per_second and self.last_request_time: + min_interval = timedelta(seconds=1.0 / self.requests_per_second) + time_since_last = datetime.now() - self.last_request_time + if time_since_last < min_interval: + time.sleep((min_interval - time_since_last).total_seconds()) + self.last_request_time = datetime.now() + + async def _safe_process_url(self, url: str) -> bool: + """Perform safety checks before processing a URL.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + await self._wait_for_rate_limit() + return True + + def _safe_process_url_sync(self, url: str) -> bool: + """Synchronous version of safety checks.""" + if self.verify_ssl and not self._verify_ssl_cert(url): + raise ValueError(f"SSL certificate verification failed for {url}") + self._sync_wait_for_rate_limit() + return True + + class SafeWebBaseLoader(WebBaseLoader): """WebBaseLoader with enhanced error handling for URLs.""" @@ -143,20 +467,12 @@ class SafeWebBaseLoader(WebBaseLoader): text = soup.get_text(**self.bs_get_text_kwargs) # Build metadata - metadata = {"source": path} - if title := soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get( - "content", "No description found." 
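Both `SafeFireCrawlLoader` and `SafePlaywrightURLLoader` repeat the same `requests_per_second` throttle in their `_wait_for_rate_limit` / `_sync_wait_for_rate_limit` methods. The core of that pattern can be sketched standalone (the `RateLimiter` class name is illustrative, not part of the diff):

```python
import time
from datetime import datetime, timedelta
from typing import Optional

class RateLimiter:
    """Minimal sketch of the per-loader requests-per-second throttle."""

    def __init__(self, requests_per_second: Optional[float] = None):
        self.requests_per_second = requests_per_second
        self.last_request_time: Optional[datetime] = None

    def wait(self) -> None:
        # Sleep just long enough to keep at most N requests per second.
        if self.requests_per_second and self.last_request_time:
            min_interval = timedelta(seconds=1.0 / self.requests_per_second)
            elapsed = datetime.now() - self.last_request_time
            if elapsed < min_interval:
                time.sleep((min_interval - elapsed).total_seconds())
        self.last_request_time = datetime.now()
```

The first call never sleeps (no previous request recorded); subsequent calls pad the gap up to `1 / requests_per_second` seconds, exactly as the synchronous helper in the diff does.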
- ) - if html := soup.find("html"): - metadata["language"] = html.get("lang", "No language found.") + metadata = extract_metadata(soup, path) yield Document(page_content=text, metadata=metadata) except Exception as e: # Log the error and continue with the next URL - log.error(f"Error loading {path}: {e}") + log.exception(e, "Error loading %s", path) async def alazy_load(self) -> AsyncIterator[Document]: """Async lazy load text from the url(s) in web_path.""" @@ -179,6 +495,12 @@ class SafeWebBaseLoader(WebBaseLoader): return [document async for document in self.alazy_load()] +RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader) +RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader +RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader +RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader + + def get_web_loader( urls: Union[str, Sequence[str]], verify_ssl: bool = True, @@ -188,10 +510,29 @@ def get_web_loader( # Check if the URLs are valid safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls) - return SafeWebBaseLoader( - web_path=safe_urls, - verify_ssl=verify_ssl, - requests_per_second=requests_per_second, - continue_on_failure=True, - trust_env=trust_env, + web_loader_args = { + "web_paths": safe_urls, + "verify_ssl": verify_ssl, + "requests_per_second": requests_per_second, + "continue_on_failure": True, + "trust_env": trust_env, + } + + if PLAYWRIGHT_WS_URI.value: + web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value + + if RAG_WEB_LOADER_ENGINE.value == "firecrawl": + web_loader_args["api_key"] = FIRECRAWL_API_KEY.value + web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value + + # Create the appropriate WebLoader based on the configuration + WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value] + web_loader = WebLoaderClass(**web_loader_args) + + log.debug( + "Using RAG_WEB_LOADER_ENGINE %s for %s URLs", + web_loader.__class__.__name__, + len(safe_urls), ) + + return web_loader diff 
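The `RAG_WEB_LOADER_ENGINES` table introduced above uses a `defaultdict` so that an unknown or unset engine name falls back to `SafeWebBaseLoader` instead of raising `KeyError`. A minimal sketch of that dispatch pattern, with trivial stand-in callables in place of the real loader classes:

```python
from collections import defaultdict

# Hypothetical stand-ins for the real loader classes in the diff.
def safe_web_loader(urls):
    return ("SafeWebBaseLoader", urls)

def playwright_loader(urls):
    return ("SafePlaywrightURLLoader", urls)

# Unknown engine names silently resolve to the safe default.
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: safe_web_loader)
RAG_WEB_LOADER_ENGINES["playwright"] = playwright_loader
RAG_WEB_LOADER_ENGINES["safe_web"] = safe_web_loader

loader = RAG_WEB_LOADER_ENGINES["no-such-engine"](["https://example.com"])
```

The default-factory lambda is what makes a misconfigured `RAG_WEB_LOADER_ENGINE` value degrade gracefully rather than crash loader construction.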
--git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index e2d05ba908..c949e65a46 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -37,6 +37,7 @@ from open_webui.config import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, ENV, SRC_LOG_LEVELS, DEVICE_TYPE, @@ -70,7 +71,7 @@ from pydub.utils import mediainfo def is_mp4_audio(file_path): """Check if the given file is an MP4 audio file.""" if not os.path.isfile(file_path): - print(f"File not found: {file_path}") + log.error(f"File not found: {file_path}") return False info = mediainfo(file_path) @@ -87,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path): """Convert MP4 audio file to WAV format.""" audio = AudioSegment.from_file(file_path, format="mp4") audio.export(output_path, format="wav") - print(f"Converted {file_path} to {output_path}") + log.info(f"Converted {file_path} to {output_path}") def set_faster_whisper_model(model: str, auto_update: bool = False): @@ -265,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): payload["model"] = request.app.state.config.TTS_MODEL try: - # print(payload) - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + async with aiohttp.ClientSession( + timeout=timeout, trust_env=True + ) as session: async with session.post( url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", json=payload, @@ -323,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)): ) try: - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + async with aiohttp.ClientSession( + timeout=timeout, trust_env=True + ) as session: async with session.post( f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", json={ @@ -380,7 +386,10 @@ async def speech(request: Request, 
user=Depends(get_verified_user)): data = f""" {payload["input"]} """ - async with aiohttp.ClientSession() as session: + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + async with aiohttp.ClientSession( + timeout=timeout, trust_env=True + ) as session: async with session.post( f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1", headers={ @@ -458,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): def transcribe(request: Request, file_path): - print("transcribe", file_path) + log.info(f"transcribe: {file_path}") filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] @@ -670,7 +679,22 @@ def transcription( def get_available_models(request: Request) -> list[dict]: available_models = [] if request.app.state.config.TTS_ENGINE == "openai": - available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + # Use custom endpoint if not using the official OpenAI API URL + if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith( + "https://api.openai.com" + ): + try: + response = requests.get( + f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models" + ) + response.raise_for_status() + data = response.json() + available_models = data.get("models", []) + except Exception as e: + log.error(f"Error fetching models from custom endpoint: {str(e)}") + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] + else: + available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}] elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: response = requests.get( @@ -701,14 +725,37 @@ def get_available_voices(request) -> dict: """Returns {voice_id: voice_name} dict""" available_voices = {} if request.app.state.config.TTS_ENGINE == "openai": - available_voices = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", - } + # Use custom endpoint if not using the official OpenAI API URL + if not 
request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith( + "https://api.openai.com" + ): + try: + response = requests.get( + f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices" + ) + response.raise_for_status() + data = response.json() + voices_list = data.get("voices", []) + available_voices = {voice["id"]: voice["name"] for voice in voices_list} + except Exception as e: + log.error(f"Error fetching voices from custom endpoint: {str(e)}") + available_voices = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", + } + else: + available_voices = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", + } elif request.app.state.config.TTS_ENGINE == "elevenlabs": try: available_voices = get_elevenlabs_voices( diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 16db98ff5b..f01b5bd747 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -31,10 +31,7 @@ from open_webui.env import ( ) from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse, Response -from open_webui.config import ( - OPENID_PROVIDER_URL, - ENABLE_OAUTH_SIGNUP, -) +from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP from pydantic import BaseModel from open_webui.utils.misc import parse_duration, validate_email_format from open_webui.utils.auth import ( @@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions from typing import Optional, List from ssl import CERT_REQUIRED, PROTOCOL_TLS -from ldap3 import Server, Connection, NONE, Tls -from ldap3.utils.conv import escape_filter_chars + +if ENABLE_LDAP.value: + from ldap3 import Server, Connection, NONE, Tls + from ldap3.utils.conv import escape_filter_chars router = APIRouter() @@ -252,14 +251,6 @@ async def 
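The voice-listing change in `audio.py` follows a fetch-with-fallback shape: only non-official endpoints are queried at `{base_url}/audio/voices`, and any failure falls back to the static OpenAI voice set. A hedged sketch of that logic (`resolve_voices` and the injectable `fetch` callable are illustrative, not the router's actual API):

```python
DEFAULT_OPENAI_VOICES = {
    "alloy": "alloy", "echo": "echo", "fable": "fable",
    "onyx": "onyx", "nova": "nova", "shimmer": "shimmer",
}

def resolve_voices(base_url: str, fetch=None):
    """Return {voice_id: name}; fall back to the static set on any failure."""
    # Official OpenAI endpoints have no /audio/voices route, so skip the fetch.
    if base_url.startswith("https://api.openai.com"):
        return DEFAULT_OPENAI_VOICES
    try:
        # `fetch` stands in for the HTTP GET; expected shape: [{"id": ..., "name": ...}]
        voices = fetch(f"{base_url}/audio/voices")
        return {voice["id"]: voice["name"] for voice in voices}
    except Exception:
        return DEFAULT_OPENAI_VOICES
```

Any exception, including a missing or broken `fetch`, lands on the same hard-coded default the pre-change code returned unconditionally.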
ldap_auth(request: Request, response: Response, form_data: LdapForm): if not user: try: user_count = Users.get_num_users() - if ( - request.app.state.USER_COUNT - and user_count >= request.app.state.USER_COUNT - ): - raise HTTPException( - status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) role = ( "admin" @@ -423,7 +414,6 @@ async def signin(request: Request, response: Response, form_data: SigninForm): @router.post("/signup", response_model=SessionUserResponse) async def signup(request: Request, response: Response, form_data: SignupForm): - user_count = Users.get_num_users() if WEBUI_AUTH: if ( @@ -434,16 +424,12 @@ async def signup(request: Request, response: Response, form_data: SignupForm): status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) else: - if user_count != 0: + if Users.get_num_users() != 0: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) - if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - + user_count = Users.get_num_users() if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT @@ -546,7 +532,8 @@ async def signout(request: Request, response: Response): if logout_url: response.delete_cookie("oauth_id_token") return RedirectResponse( - url=f"{logout_url}?id_token_hint={oauth_id_token}" + headers=response.headers, + url=f"{logout_url}?id_token_hint={oauth_id_token}", ) else: raise HTTPException( @@ -612,7 +599,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None - print(admin_email, admin_name) + log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") if admin_email: admin = Users.get_user_by_email(admin_email) diff --git 
a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 016075234a..388c44f9c6 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -70,6 +70,12 @@ async def set_direct_connections_config( # CodeInterpreterConfig ############################ class CodeInterpreterConfigForm(BaseModel): + CODE_EXECUTION_ENGINE: str + CODE_EXECUTION_JUPYTER_URL: Optional[str] + CODE_EXECUTION_JUPYTER_AUTH: Optional[str] + CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str] + CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str] + CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int] ENABLE_CODE_INTERPRETER: bool CODE_INTERPRETER_ENGINE: str CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str] @@ -77,11 +83,18 @@ class CodeInterpreterConfigForm(BaseModel): CODE_INTERPRETER_JUPYTER_AUTH: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str] CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str] + CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int] -@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm) -async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)): +@router.get("/code_execution", response_model=CodeInterpreterConfigForm) +async def get_code_execution_config(request: Request, user=Depends(get_admin_user)): return { + "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE, + "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL, + "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, + "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, + "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, "CODE_INTERPRETER_ENGINE": 
request.app.state.config.CODE_INTERPRETER_ENGINE, "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, @@ -89,13 +102,32 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, } -@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm) -async def set_code_interpreter_config( +@router.post("/code_execution", response_model=CodeInterpreterConfigForm) +async def set_code_execution_config( request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user) ): + + request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE + request.app.state.config.CODE_EXECUTION_JUPYTER_URL = ( + form_data.CODE_EXECUTION_JUPYTER_URL + ) + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = ( + form_data.CODE_EXECUTION_JUPYTER_AUTH + ) + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = ( + form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN + ) + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = ( + form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD + ) + request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = ( + form_data.CODE_EXECUTION_JUPYTER_TIMEOUT + ) + request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = ( @@ -116,8 +148,17 @@ async def set_code_interpreter_config( request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = ( 
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD ) + request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = ( + form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT + ) return { + "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE, + "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL, + "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH, + "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN, + "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER, "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE, "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE, @@ -125,6 +166,7 @@ async def set_code_interpreter_config( "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH, "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, } diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 0513212571..95b7f6461a 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -16,6 +16,7 @@ from open_webui.models.files import ( Files, ) from open_webui.routers.retrieval import ProcessFileForm, process_file +from open_webui.routers.audio import transcribe from open_webui.storage.provider import Storage from open_webui.utils.auth import get_admin_user, get_verified_user from pydantic import BaseModel @@ -67,7 +68,22 @@ def upload_file( ) 
try: - process_file(request, ProcessFileForm(file_id=id), user=user) + if file.content_type in [ + "audio/mpeg", + "audio/wav", + "audio/ogg", + "audio/x-m4a", + ]: + file_path = Storage.get_file(file_path) + result = transcribe(request, file_path) + process_file( + request, + ProcessFileForm(file_id=id, content=result.get("text", "")), + user=user, + ) + else: + process_file(request, ProcessFileForm(file_id=id), user=user) + file_item = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) @@ -225,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): filename = file.meta.get("name", file.filename) encoded_filename = quote(filename) # RFC5987 encoding + content_type = file.meta.get("content_type") + filename = file.meta.get("name", file.filename) + encoded_filename = quote(filename) headers = {} - if file.meta.get("content_type") not in [ - "application/pdf", - "text/plain", - ]: - headers = { - **headers, - "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}", - } - return FileResponse(file_path, headers=headers) + if content_type == "application/pdf" or filename.lower().endswith( + ".pdf" + ): + headers["Content-Disposition"] = ( + f"inline; filename*=UTF-8''{encoded_filename}" + ) + content_type = "application/pdf" + elif content_type != "text/plain": + headers["Content-Disposition"] = ( + f"attachment; filename*=UTF-8''{encoded_filename}" + ) + + return FileResponse(file_path, headers=headers, media_type=content_type) else: raise HTTPException( @@ -266,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): # Check if the file already exists in the cache if file_path.is_file(): - print(f"file_path: {file_path}") + log.info(f"file_path: {file_path}") return FileResponse(file_path) else: raise HTTPException( diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 7f3305f25a..ac2db9322a 100644 --- 
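The `files.py` hunk above replaces a blanket `attachment` disposition with a three-way decision: PDFs (by declared content type or `.pdf` filename) render inline, plain text gets no disposition header, and everything else downloads. Extracted as a pure function (the name `disposition_headers` is hypothetical):

```python
from typing import Optional, Tuple
from urllib.parse import quote

def disposition_headers(filename: str, content_type: Optional[str]) -> Tuple[dict, Optional[str]]:
    """Sketch of the header logic: PDFs inline, text/plain bare, the rest attached."""
    encoded = quote(filename)  # RFC 5987 percent-encoding for filename*
    headers = {}
    if content_type == "application/pdf" or filename.lower().endswith(".pdf"):
        headers["Content-Disposition"] = f"inline; filename*=UTF-8''{encoded}"
        content_type = "application/pdf"
    elif content_type != "text/plain":
        headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded}"
    return headers, content_type
```

Coercing `content_type` to `application/pdf` when only the extension matched is what lets browsers open stored PDFs in-page even when the upload recorded no content type.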
a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -1,4 +1,5 @@ import os +import logging from pathlib import Path from typing import Optional @@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + router = APIRouter() @@ -79,7 +85,7 @@ async def create_new_function( detail=ERROR_MESSAGES.DEFAULT("Error creating function"), ) except Exception as e: - print(e) + log.exception(f"Failed to create a new function: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -183,7 +189,7 @@ async def update_function_by_id( FUNCTIONS[id] = function_module updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} - print(updated) + log.debug(updated) function = Functions.update_function_by_id(id, updated) @@ -299,7 +305,7 @@ async def update_function_valves_by_id( Functions.update_function_valves_by_id(id, valves.model_dump()) return valves.model_dump() except Exception as e: - print(e) + log.exception(f"Error updating function values by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -388,7 +394,7 @@ async def update_function_user_valves_by_id( ) return user_valves.model_dump() except Exception as e: - print(e) + log.exception(f"Error updating function user valves by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py old mode 100644 new mode 100755 index 5b5130f71d..ae822c0d00 --- a/backend/open_webui/routers/groups.py +++ 
b/backend/open_webui/routers/groups.py @@ -1,7 +1,7 @@ import os from pathlib import Path from typing import Optional - +import logging from open_webui.models.users import Users from open_webui.models.groups import ( @@ -14,7 +14,13 @@ from open_webui.models.groups import ( from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status + from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() @@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[GroupResponse]) -async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)): +async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): try: group = Groups.insert_new_group(user.id, form_data) if group: @@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user) detail=ERROR_MESSAGES.DEFAULT("Error creating group"), ) except Exception as e: - print(e) + log.exception(f"Error creating a new group: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -94,7 +100,7 @@ async def update_group_by_id( detail=ERROR_MESSAGES.DEFAULT("Error updating group"), ) except Exception as e: - print(e) + log.exception(f"Error updating group {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), @@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)): detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), ) except Exception as e: - print(e) + log.exception(f"Error deleting group {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), diff --git 
a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 4046773dea..131fa2df4d 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES, }, + "gemini": { + "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, + }, } @@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel): COMFYUI_WORKFLOW_NODES: list[dict] +class GeminiConfigForm(BaseModel): + GEMINI_API_BASE_URL: str + GEMINI_API_KEY: str + + class ConfigForm(BaseModel): enabled: bool engine: str @@ -85,6 +94,7 @@ class ConfigForm(BaseModel): openai: OpenAIConfigForm automatic1111: Automatic1111ConfigForm comfyui: ComfyUIConfigForm + gemini: GeminiConfigForm @router.post("/config/update") @@ -103,6 +113,11 @@ async def update_config( ) request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY + request.app.state.config.IMAGES_GEMINI_API_BASE_URL = ( + form_data.gemini.GEMINI_API_BASE_URL + ) + request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY + request.app.state.config.AUTOMATIC1111_BASE_URL = ( form_data.automatic1111.AUTOMATIC1111_BASE_URL ) @@ -129,6 +144,8 @@ async def update_config( request.app.state.config.COMFYUI_BASE_URL = ( form_data.comfyui.COMFYUI_BASE_URL.strip("/") ) + request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY + request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW request.app.state.config.COMFYUI_WORKFLOW_NODES = ( form_data.comfyui.COMFYUI_WORKFLOW_NODES @@ -155,6 +172,10 @@ async def update_config( "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW, "COMFYUI_WORKFLOW_NODES": 
request.app.state.config.COMFYUI_WORKFLOW_NODES, }, + "gemini": { + "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL, + "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY, + }, } @@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)): request.app.state.config.ENABLE_IMAGE_GENERATION = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": + + headers = None + if request.app.state.config.COMFYUI_API_KEY: + headers = { + "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" + } + try: r = requests.get( - url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info" + url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info", + headers=headers, ) r.raise_for_status() return True @@ -224,6 +253,12 @@ def get_image_model(request): if request.app.state.config.IMAGE_GENERATION_MODEL else "dall-e-2" ) + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + return ( + request.app.state.config.IMAGE_GENERATION_MODEL + if request.app.state.config.IMAGE_GENERATION_MODEL + else "imagen-3.0-generate-002" + ) elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": return ( request.app.state.config.IMAGE_GENERATION_MODEL @@ -299,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)): {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + return [ + {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"}, + ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui headers = { @@ -322,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)): if model_node_id: model_list_key = None - print(workflow[model_node_id]["class_type"]) + 
log.info(workflow[model_node_id]["class_type"]) for key in info[workflow[model_node_id]["class_type"]]["input"][ "required" ]: @@ -483,6 +522,41 @@ async def image_generations( images.append({"url": url}) return images + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": + headers = {} + headers["Content-Type"] = "application/json" + headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY + + model = get_image_model(request) + data = { + "instances": {"prompt": form_data.prompt}, + "parameters": { + "sampleCount": form_data.n, + "outputOptions": {"mimeType": "image/png"}, + }, + } + + # Use asyncio.to_thread for the requests.post call + r = await asyncio.to_thread( + requests.post, + url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict", + json=data, + headers=headers, + ) + + r.raise_for_status() + res = r.json() + + images = [] + for image in res["predictions"]: + image_data, content_type = load_b64_image_data( + image["bytesBase64Encoded"] + ) + url = upload_image(request, data, image_data, content_type, user) + images.append({"url": url}) + + return images + elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": data = { "prompt": form_data.prompt, diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 0ba6191a2a..1969045505 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -614,7 +614,7 @@ def add_files_to_knowledge_batch( ) # Get files content - print(f"files/batch/add - {len(form_data)} files") + log.info(f"files/batch/add - {len(form_data)} files") files: List[FileModel] = [] for form in form_data: file = Files.get_file_by_id(form.file_id) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 4fca10e1f5..d99416c832 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -14,6 +14,11 @@ from urllib.parse 
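The new Gemini branch in `images.py` posts an Imagen `:predict` payload. The request body it constructs can be sketched as a small builder (the helper name is illustrative; the field names are taken from the diff):

```python
def build_imagen_request(prompt: str, n: int) -> dict:
    """Payload shape for the Imagen :predict endpoint, per the diff."""
    return {
        "instances": {"prompt": prompt},
        "parameters": {
            "sampleCount": n,                          # number of images requested
            "outputOptions": {"mimeType": "image/png"},  # force PNG output
        },
    }
```

The response's `predictions` list then carries one `bytesBase64Encoded` entry per requested sample, which the handler decodes and stores.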
import urlparse import aiohttp from aiocache import cached import requests +from open_webui.models.users import UserModel + +from open_webui.env import ( + ENABLE_FORWARD_USER_INFO_HEADERS, +) from fastapi import ( Depends, @@ -26,7 +31,7 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, validator from starlette.background import BackgroundTask @@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -96,6 +115,7 @@ async def send_post_request( stream: bool = True, key: Optional[str] = None, content_type: Optional[str] = None, + user: UserModel = None, ): r = None @@ -110,6 +130,16 @@ async def send_post_request( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -191,7 +221,19 @@ async def 
verify_connection( try: async with session.get( f"{url}/api/version", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as r: if r.status != 200: detail = f"HTTP Error: {r.status}" @@ -254,7 +296,7 @@ async def update_config( @cached(ttl=3) -async def get_all_models(request: Request): +async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] @@ -262,7 +304,7 @@ async def get_all_models(request: Request): if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support ): - request_tasks.append(send_get_request(f"{url}/api/tags")) + request_tasks.append(send_get_request(f"{url}/api/tags", user=user)) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), @@ -275,7 +317,9 @@ async def get_all_models(request: Request): key = api_config.get("key", None) if enable: - request_tasks.append(send_get_request(f"{url}/api/tags", key)) + request_tasks.append( + send_get_request(f"{url}/api/tags", key, user=user) + ) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -360,7 +404,7 @@ async def get_ollama_tags( models = [] if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -370,7 +414,19 @@ async def get_ollama_tags( r = requests.request( method="GET", url=f"{url}/api/tags", - headers={**({"Authorization": f"Bearer {key}"} if key else 
{})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) r.raise_for_status() @@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u url, {} ), # Legacy support ).get("key", None), + user=user, ) for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] @@ -509,6 +566,7 @@ async def pull_model( url=f"{url}/api/pull", payload=json.dumps(payload), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -527,7 +585,7 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -545,6 +603,7 @@ async def push_model( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -571,6 +630,7 @@ async def create_model( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -588,7 +648,7 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.source in models: @@ -609,6 +669,16 @@ async def copy_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if 
ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -643,7 +713,7 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -665,6 +735,16 @@ async def delete_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -693,7 +773,7 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name not in models: @@ -714,6 +794,16 @@ async def show_model_info( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -757,7 +847,7 @@ async def embed( log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -783,6 +873,16 @@ async def embed( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + 
"X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -826,7 +926,7 @@ async def embeddings( log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -852,6 +952,16 @@ async def embeddings( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -901,7 +1011,7 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -931,15 +1041,29 @@ async def generate_completion( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) class ChatMessage(BaseModel): role: str - content: str + content: Optional[str] = None tool_calls: Optional[list[dict]] = None images: Optional[list[str]] = None + @validator("content", pre=True) + @classmethod + def check_at_least_one_field(cls, field_value, values, **kwargs): + # Raise an error if both 'content' and 'tool_calls' are None + if field_value is None and ( + "tool_calls" not in values or values["tool_calls"] is None + ): + raise ValueError( + "At least one of 'content' or 'tool_calls' must be provided" + ) + + return field_value + class GenerateChatCompletionForm(BaseModel): model: str @@ -1047,6 
+1171,7 @@ async def generate_chat_completion( stream=form_data.stream, key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", + user=user, ) @@ -1149,6 +1274,7 @@ async def generate_openai_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1227,6 +1353,7 @@ async def generate_openai_chat_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1240,7 +1367,7 @@ async def get_openai_models( models = [] if url_idx is None: - model_list = await get_all_models(request) + model_list = await get_all_models(request, user=user) models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index afda362373..990df83b0b 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -26,6 +26,7 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, BYPASS_MODEL_ACCESS_CONTROL, ) +from open_webui.models.users import UserModel from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS @@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + 
"X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -84,9 +98,15 @@ def openai_o1_o3_handler(payload): payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] - # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + # Fix: o1 and o3 do not support the "system" role directly. + # For older models like "o1-mini" or "o1-preview", use role "user". + # For newer o1/o3 models, replace "system" with "developer". if payload["messages"][0]["role"] == "system": - payload["messages"][0]["role"] = "user" + model_lower = payload["model"].lower() + if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"): + payload["messages"][0]["role"] = "user" + else: + payload["messages"][0]["role"] = "developer" return payload @@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def get_all_models_responses(request: Request) -> list: +async def get_all_models_responses(request: Request, user: UserModel) -> list: if not request.app.state.config.ENABLE_OPENAI_API: return [] @@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list: ): request_tasks.append( send_get_request( - f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list: send_get_request( f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -352,13 +375,13 @@ async def get_filtered_models(models, user): @cached(ttl=3) -async def get_all_models(request: Request) -> dict[str, list]: +async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: 
log.info("get_all_models()") if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses(request) + responses = await get_all_models_responses(request, user=user) def extract_data(response): if response and "data" in response: @@ -418,7 +441,7 @@ async def get_models( } if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx] @@ -515,6 +538,16 @@ async def verify_connection( headers={ "Authorization": f"Bearer {key}", "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), }, ) as r: if r.status != 200: @@ -587,7 +620,7 @@ async def generate_chat_completion( detail="Model not found", ) - await get_all_models(request) + await get_all_models(request, user=user) model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] @@ -777,7 +810,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): if r is not None: try: res = await r.json() - print(res) + log.error(res) if "error" in res: detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" except Exception: diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index ad280b65c1..599208e43d 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -101,7 +101,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models): if "detail" in res: raise Exception(response.status, res["detail"]) except Exception as e: - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") return payload @@ -153,7 
+153,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models): except Exception: pass except Exception as e: - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") return payload @@ -169,7 +169,7 @@ router = APIRouter() @router.get("/list") async def get_pipelines_list(request: Request, user=Depends(get_admin_user)): - responses = await get_all_models_responses(request) + responses = await get_all_models_responses(request, user) log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") urlIdxs = [ @@ -196,7 +196,7 @@ async def upload_pipeline( file: UploadFile = File(...), user=Depends(get_admin_user), ): - print("upload_pipeline", urlIdx, file.filename) + log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}") # Check if the uploaded file is a python file if not (file.filename and file.filename.endswith(".py")): raise HTTPException( @@ -231,7 +231,7 @@ async def upload_pipeline( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None status_code = status.HTTP_404_NOT_FOUND @@ -282,7 +282,7 @@ async def add_pipeline( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -327,7 +327,7 @@ async def delete_pipeline( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -361,7 +361,7 @@ async def get_pipelines( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -400,7 +400,7 @@ async def get_pipeline_valves( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: 
{e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -440,7 +440,7 @@ async def get_pipeline_valves_spec( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None if r is not None: @@ -482,7 +482,7 @@ async def update_pipeline_valves( return {**data} except Exception as e: # Handle connection error here - print(f"Connection error: {e}") + log.exception(f"Connection error: {e}") detail = None diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 47e6253c88..7dd324b80b 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -351,10 +351,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): return { "status": True, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, "enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, "content_extraction": { "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "tika_server_url": request.app.state.config.TIKA_SERVER_URL, + "document_intelligence_config": { + "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + }, }, "chunk": { "text_splitter": request.app.state.config.TEXT_SPLITTER, @@ -371,10 +378,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, }, "web": { - "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": 
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, "search": { "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, "drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, + "onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION, "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL, "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY, @@ -397,6 +406,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "exa_api_key": request.app.state.config.EXA_API_KEY, "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + "trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV, "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, "domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, }, @@ -409,9 +419,15 @@ class FileConfig(BaseModel): max_count: Optional[int] = None +class DocumentIntelligenceConfigForm(BaseModel): + endpoint: str + key: str + + class ContentExtractionConfig(BaseModel): engine: str = "" tika_server_url: Optional[str] = None + document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None class ChunkParamUpdateForm(BaseModel): @@ -457,12 +473,16 @@ class WebSearchConfig(BaseModel): class WebConfig(BaseModel): search: WebSearchConfig - web_loader_ssl_verification: Optional[bool] = None + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None + BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None class ConfigUpdateForm(BaseModel): + RAG_FULL_CONTEXT: Optional[bool] = None + BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None pdf_extract_images: Optional[bool] = None enable_google_drive_integration: 
Optional[bool] = None + enable_onedrive_integration: Optional[bool] = None file: Optional[FileConfig] = None content_extraction: Optional[ContentExtractionConfig] = None chunk: Optional[ChunkParamUpdateForm] = None @@ -480,24 +500,51 @@ async def update_rag_config( else request.app.state.config.PDF_EXTRACT_IMAGES ) + request.app.state.config.RAG_FULL_CONTEXT = ( + form_data.RAG_FULL_CONTEXT + if form_data.RAG_FULL_CONTEXT is not None + else request.app.state.config.RAG_FULL_CONTEXT + ) + + request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = ( + form_data.BYPASS_EMBEDDING_AND_RETRIEVAL + if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None + else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ) + request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ( form_data.enable_google_drive_integration if form_data.enable_google_drive_integration is not None else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION ) + request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ( + form_data.enable_onedrive_integration + if form_data.enable_onedrive_integration is not None + else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION + ) + if form_data.file is not None: request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count if form_data.content_extraction is not None: - log.info(f"Updating text settings: {form_data.content_extraction}") + log.info( + f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}" + ) request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( form_data.content_extraction.engine ) request.app.state.config.TIKA_SERVER_URL = ( form_data.content_extraction.tika_server_url ) + if form_data.content_extraction.document_intelligence_config is not None: + request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( + form_data.content_extraction.document_intelligence_config.endpoint + ) + 
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( + form_data.content_extraction.document_intelligence_config.key + ) if form_data.chunk is not None: request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter @@ -512,11 +559,16 @@ async def update_rag_config( if form_data.web is not None: request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False - form_data.web.web_loader_ssl_verification + form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ) request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine + + request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( + form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL + ) + request.app.state.config.SEARXNG_QUERY_URL = ( form_data.web.search.searxng_query_url ) @@ -581,6 +633,8 @@ async def update_rag_config( return { "status": True, "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES, + "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT, + "BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL, "file": { "max_size": request.app.state.config.FILE_MAX_SIZE, "max_count": request.app.state.config.FILE_MAX_COUNT, @@ -588,6 +642,10 @@ async def update_rag_config( "content_extraction": { "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "tika_server_url": request.app.state.config.TIKA_SERVER_URL, + "document_intelligence_config": { + "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + }, }, "chunk": { "text_splitter": request.app.state.config.TEXT_SPLITTER, @@ -600,7 +658,8 @@ async def update_rag_config( "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION, }, "web": { - "web_loader_ssl_verification": 
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, "search": { "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH, "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE, @@ -863,7 +922,12 @@ def process_file( # Update the content in the file # Usage: /files/{file_id}/data/content/update - VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}") + try: + # /files/{file_id}/data/content/update + VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}") + except: + # Audio file upload pipeline + pass docs = [ Document( @@ -920,6 +984,8 @@ def process_file( engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, + DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -962,36 +1028,45 @@ def process_file( hash = calculate_sha256_string(text_content) Files.update_file_hash_by_id(file.id, hash) - try: - result = save_docs_to_vector_db( - request, - docs=docs, - collection_name=collection_name, - metadata={ - "file_id": file.id, - "name": file.filename, - "hash": hash, - }, - add=(True if form_data.collection_name else False), - user=user, - ) - - if result: - Files.update_file_metadata_by_id( - file.id, - { - "collection_name": collection_name, + if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: + try: + result = save_docs_to_vector_db( + request, + docs=docs, + collection_name=collection_name, + metadata={ + "file_id": file.id, + "name": file.filename, + "hash": hash, }, + 
add=(True if form_data.collection_name else False), + user=user, ) - return { - "status": True, - "collection_name": collection_name, - "filename": file.filename, - "content": text_content, - } - except Exception as e: - raise e + if result: + Files.update_file_metadata_by_id( + file.id, + { + "collection_name": collection_name, + }, + ) + + return { + "status": True, + "collection_name": collection_name, + "filename": file.filename, + "content": text_content, + } + except Exception as e: + raise e + else: + return { + "status": True, + "collection_name": None, + "filename": file.filename, + "content": text_content, + } + except Exception as e: log.exception(e) if "No pandoc was found" in str(e): @@ -1262,6 +1337,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: request.app.state.config.TAVILY_API_KEY, query, request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, + request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No TAVILY_API_KEY found in environment variables") @@ -1349,21 +1425,37 @@ async def process_web_search( trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV, ) docs = await loader.aload() - await run_in_threadpool( - save_docs_to_vector_db, - request, - docs, - collection_name, - overwrite=True, - user=user, - ) - return { - "status": True, - "collection_name": collection_name, - "filenames": urls, - "loaded_count": len(docs), - } + if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: + return { + "status": True, + "collection_name": None, + "filenames": urls, + "docs": [ + { + "content": doc.page_content, + "metadata": doc.metadata, + } + for doc in docs + ], + "loaded_count": len(docs), + } + else: + await run_in_threadpool( + save_docs_to_vector_db, + request, + docs, + collection_name, + overwrite=True, + user=user, + ) + + return { + "status": True, + "collection_name": collection_name, + "filenames": urls, + "loaded_count": len(docs), + } except Exception 
as e: log.exception(e) raise HTTPException( @@ -1520,11 +1612,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool: elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") + log.exception(f"Failed to delete {file_path}. Reason: {e}") else: - print(f"The directory {folder} does not exist") + log.warning(f"The directory {folder} does not exist") except Exception as e: - print(f"Failed to process the directory {folder}. Reason: {e}") + log.exception(f"Failed to process the directory {folder}. Reason: {e}") return True diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 8b17c6c4b3..b63c9732af 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -20,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.constants import TASKS from open_webui.routers.pipelines import process_pipeline_inlet_filter +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) from open_webui.utils.task import get_task_model_id from open_webui.config import ( @@ -208,7 +212,7 @@ async def generate_title( "stream": False, **( {"max_tokens": 1000} - if models[task_model_id]["owned_by"] == "ollama" + if models[task_model_id].get("owned_by") == "ollama" else { "max_completion_tokens": 1000, } @@ -221,6 +225,12 @@ async def generate_title( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -290,6 +300,12 @@ async def generate_chat_tags( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e 
+ try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -356,6 +372,12 @@ async def generate_image_prompt( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -433,6 +455,12 @@ async def generate_queries( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -514,6 +542,12 @@ async def generate_autocompletion( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -571,7 +605,7 @@ async def generate_emoji( "stream": False, **( {"max_tokens": 4} - if models[task_model_id]["owned_by"] == "ollama" + if models[task_model_id].get("owned_by") == "ollama" else { "max_completion_tokens": 4, } @@ -584,6 +618,12 @@ async def generate_emoji( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: @@ -644,6 +684,12 @@ async def generate_moa_response( }, } + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + try: return await generate_chat_completion(request, form_data=payload, user=user) except Exception as e: diff 
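The tasks router now pushes every generated payload through `process_pipeline_inlet_filter` before calling `generate_chat_completion`. A minimal synchronous sketch of an inlet-filter chain — the real implementation is async and forwards payloads to external pipeline servers; the `filters` list and lambdas here are purely illustrative:

```python
# Illustrative sketch of chaining "inlet" filters over a task payload;
# each filter receives the payload and returns a (possibly modified) copy.
def process_inlet_filters(payload, filters):
    for f in filters:
        payload = f(payload)
    return payload


filters = [
    lambda p: {**p, "stream": False},  # task endpoints force non-streaming
    lambda p: {**p, "max_tokens": min(p.get("max_tokens", 1000), 1000)},
]

payload = {"model": "llama3", "max_tokens": 4000}
print(process_inlet_filters(payload, filters))
# {'model': 'llama3', 'max_tokens': 1000, 'stream': False}
```

Running the filter chain before dispatch means every task (title, tags, queries, emoji, MoA) gets the same payload normalization without duplicating logic per endpoint.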
--git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index d6a5c5532f..5e4109037d 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Optional @@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.tools import get_tools_specs from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() @@ -111,7 +116,7 @@ async def create_new_tools( detail=ERROR_MESSAGES.DEFAULT("Error creating tools"), ) except Exception as e: - print(e) + log.exception(f"Failed to load the tool by id {form_data.id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -193,7 +198,7 @@ async def update_tools_by_id( "specs": specs, } - print(updated) + log.debug(updated) tools = Tools.update_tool_by_id(id, updated) if tools: @@ -343,7 +348,7 @@ async def update_tools_valves_by_id( Tools.update_tool_valves_by_id(id, valves.model_dump()) return valves.model_dump() except Exception as e: - print(e) + log.exception(f"Failed to update tool valves by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), @@ -421,7 +426,7 @@ async def update_tools_user_valves_by_id( ) return user_valves.model_dump() except Exception as e: - print(e) + log.exception(f"Failed to update user valves by id {id}: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(str(e)), diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index ea73e97596..b64adafb44 100644 --- a/backend/open_webui/routers/utils.py +++ 
b/backend/open_webui/routers/utils.py @@ -1,48 +1,84 @@ import black +import logging import markdown from open_webui.models.chats import ChatTitleMessagesForm from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from pydantic import BaseModel from starlette.responses import FileResponse + + from open_webui.utils.misc import get_gravatar_url from open_webui.utils.pdf_generator import PDFGenerator -from open_webui.utils.auth import get_admin_user +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.code_interpreter import execute_code_jupyter +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() @router.get("/gravatar") -async def get_gravatar( - email: str, -): +async def get_gravatar(email: str, user=Depends(get_verified_user)): return get_gravatar_url(email) -class CodeFormatRequest(BaseModel): +class CodeForm(BaseModel): code: str @router.post("/code/format") -async def format_code(request: CodeFormatRequest): +async def format_code(form_data: CodeForm, user=Depends(get_verified_user)): try: - formatted_code = black.format_str(request.code, mode=black.Mode()) + formatted_code = black.format_str(form_data.code, mode=black.Mode()) return {"code": formatted_code} except black.NothingChanged: - return {"code": request.code} + return {"code": form_data.code} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) +@router.post("/code/execute") +async def execute_code( + request: Request, form_data: CodeForm, user=Depends(get_verified_user) +): + if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter": + output = await execute_code_jupyter( + request.app.state.config.CODE_EXECUTION_JUPYTER_URL, + form_data.code, + 
( + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN + if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token" + else None + ), + ( + request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD + if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password" + else None + ), + request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT, + ) + + return output + else: + raise HTTPException( + status_code=400, + detail="Code execution engine not supported", + ) + + class MarkdownForm(BaseModel): md: str @router.post("/markdown") async def get_html_from_markdown( - form_data: MarkdownForm, + form_data: MarkdownForm, user=Depends(get_verified_user) ): return {"html": markdown.markdown(form_data.md)} @@ -54,7 +90,7 @@ class ChatForm(BaseModel): @router.post("/pdf") async def download_chat_as_pdf( - form_data: ChatTitleMessagesForm, + form_data: ChatTitleMessagesForm, user=Depends(get_verified_user) ): try: pdf_bytes = PDFGenerator(form_data).generate_chat_pdf() @@ -65,7 +101,7 @@ async def download_chat_as_pdf( headers={"Content-Disposition": "attachment;filename=chat.pdf"}, ) except Exception as e: - print(e) + log.exception(f"Error generating PDF: {e}") raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index b03cf0a7ec..2f31cbdafb 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -1,10 +1,12 @@ import os import shutil import json +import logging from abc import ABC, abstractmethod from typing import BinaryIO, Tuple import boto3 +from botocore.config import Config from botocore.exceptions import ClientError from open_webui.config import ( S3_ACCESS_KEY_ID, @@ -13,14 +15,27 @@ from open_webui.config import ( S3_KEY_PREFIX, S3_REGION_NAME, S3_SECRET_ACCESS_KEY, + S3_USE_ACCELERATE_ENDPOINT, + S3_ADDRESSING_STYLE, GCS_BUCKET_NAME, GOOGLE_APPLICATION_CREDENTIALS_JSON, + AZURE_STORAGE_ENDPOINT, + 
AZURE_STORAGE_CONTAINER_NAME, + AZURE_STORAGE_KEY, STORAGE_PROVIDER, UPLOAD_DIR, ) from google.cloud import storage from google.cloud.exceptions import GoogleCloudError, NotFound from open_webui.constants import ERROR_MESSAGES +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from azure.core.exceptions import ResourceNotFoundError +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) class StorageProvider(ABC): @@ -65,7 +80,7 @@ class LocalStorageProvider(StorageProvider): if os.path.isfile(file_path): os.remove(file_path) else: - print(f"File {file_path} not found in local storage.") + log.warning(f"File {file_path} not found in local storage.") @staticmethod def delete_all_files() -> None: @@ -79,9 +94,9 @@ class LocalStorageProvider(StorageProvider): elif os.path.isdir(file_path): shutil.rmtree(file_path) # Remove the directory except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") + log.exception(f"Failed to delete {file_path}. 
Reason: {e}") else: - print(f"Directory {UPLOAD_DIR} not found in local storage.") + log.warning(f"Directory {UPLOAD_DIR} not found in local storage.") class S3StorageProvider(StorageProvider): @@ -92,6 +107,12 @@ class S3StorageProvider(StorageProvider): endpoint_url=S3_ENDPOINT_URL, aws_access_key_id=S3_ACCESS_KEY_ID, aws_secret_access_key=S3_SECRET_ACCESS_KEY, + config=Config( + s3={ + "use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT, + "addressing_style": S3_ADDRESSING_STYLE, + }, + ), ) self.bucket_name = S3_BUCKET_NAME self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else "" @@ -221,6 +242,74 @@ class GCSStorageProvider(StorageProvider): LocalStorageProvider.delete_all_files() +class AzureStorageProvider(StorageProvider): + def __init__(self): + self.endpoint = AZURE_STORAGE_ENDPOINT + self.container_name = AZURE_STORAGE_CONTAINER_NAME + storage_key = AZURE_STORAGE_KEY + + if storage_key: + # Configure using the Azure Storage Account Endpoint and Key + self.blob_service_client = BlobServiceClient( + account_url=self.endpoint, credential=storage_key + ) + else: + # Configure using the Azure Storage Account Endpoint and DefaultAzureCredential + # If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication + self.blob_service_client = BlobServiceClient( + account_url=self.endpoint, credential=DefaultAzureCredential() + ) + self.container_client = self.blob_service_client.get_container_client( + self.container_name + ) + + def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]: + """Handles uploading of the file to Azure Blob Storage.""" + contents, file_path = LocalStorageProvider.upload_file(file, filename) + try: + blob_client = self.container_client.get_blob_client(filename) + blob_client.upload_blob(contents, overwrite=True) + return contents, f"{self.endpoint}/{self.container_name}/{filename}" + except Exception as e: + raise RuntimeError(f"Error uploading file to Azure Blob
Storage: {e}") + + def get_file(self, file_path: str) -> str: + """Handles downloading of the file from Azure Blob Storage.""" + try: + filename = file_path.split("/")[-1] + local_file_path = f"{UPLOAD_DIR}/{filename}" + blob_client = self.container_client.get_blob_client(filename) + with open(local_file_path, "wb") as download_file: + download_file.write(blob_client.download_blob().readall()) + return local_file_path + except ResourceNotFoundError as e: + raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}") + + def delete_file(self, file_path: str) -> None: + """Handles deletion of the file from Azure Blob Storage.""" + try: + filename = file_path.split("/")[-1] + blob_client = self.container_client.get_blob_client(filename) + blob_client.delete_blob() + except ResourceNotFoundError as e: + raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}") + + # Always delete from local storage + LocalStorageProvider.delete_file(file_path) + + def delete_all_files(self) -> None: + """Handles deletion of all files from Azure Blob Storage.""" + try: + blobs = self.container_client.list_blobs() + for blob in blobs: + self.container_client.delete_blob(blob.name) + except Exception as e: + raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}") + + # Always delete from local storage + LocalStorageProvider.delete_all_files() + + def get_storage_provider(storage_provider: str): if storage_provider == "local": Storage = LocalStorageProvider() @@ -228,6 +317,8 @@ def get_storage_provider(storage_provider: str): Storage = S3StorageProvider() elif storage_provider == "gcs": Storage = GCSStorageProvider() + elif storage_provider == "azure": + Storage = AzureStorageProvider() else: raise RuntimeError(f"Unsupported storage provider: {storage_provider}") return Storage diff --git a/backend/open_webui/test/apps/webui/storage/test_provider.py b/backend/open_webui/test/apps/webui/storage/test_provider.py index 863106e75a..a5ef135043
100644 --- a/backend/open_webui/test/apps/webui/storage/test_provider.py +++ b/backend/open_webui/test/apps/webui/storage/test_provider.py @@ -7,6 +7,8 @@ from moto import mock_aws from open_webui.storage import provider from gcp_storage_emulator.server import create_server from google.cloud import storage +from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient +from unittest.mock import MagicMock def mock_upload_dir(monkeypatch, tmp_path): @@ -22,6 +24,7 @@ def test_imports(): provider.LocalStorageProvider provider.S3StorageProvider provider.GCSStorageProvider + provider.AzureStorageProvider provider.Storage @@ -32,6 +35,8 @@ def test_get_storage_provider(): assert isinstance(Storage, provider.S3StorageProvider) Storage = provider.get_storage_provider("gcs") assert isinstance(Storage, provider.GCSStorageProvider) + Storage = provider.get_storage_provider("azure") + assert isinstance(Storage, provider.AzureStorageProvider) with pytest.raises(RuntimeError): provider.get_storage_provider("invalid") @@ -48,6 +53,7 @@ def test_class_instantiation(): provider.LocalStorageProvider() provider.S3StorageProvider() provider.GCSStorageProvider() + provider.AzureStorageProvider() class TestLocalStorageProvider: @@ -272,3 +278,147 @@ class TestGCSStorageProvider: assert not (upload_dir / self.filename_extra).exists() assert self.Storage.bucket.get_blob(self.filename) == None assert self.Storage.bucket.get_blob(self.filename_extra) == None + + +class TestAzureStorageProvider: + def __init__(self): + super().__init__() + + @pytest.fixture(scope="class") + def setup_storage(self, monkeypatch): + # Create mock Blob Service Client and related clients + mock_blob_service_client = MagicMock() + mock_container_client = MagicMock() + mock_blob_client = MagicMock() + + # Set up return values for the mock + mock_blob_service_client.get_container_client.return_value = ( + mock_container_client + ) + mock_container_client.get_blob_client.return_value = 
mock_blob_client + + # Monkeypatch the Azure classes to return our mocks + monkeypatch.setattr( + azure.storage.blob, + "BlobServiceClient", + lambda *args, **kwargs: mock_blob_service_client, + ) + monkeypatch.setattr( + azure.storage.blob, + "ContainerClient", + lambda *args, **kwargs: mock_container_client, + ) + monkeypatch.setattr( + azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client + ) + + self.Storage = provider.AzureStorageProvider() + self.Storage.endpoint = "https://myaccount.blob.core.windows.net" + self.Storage.container_name = "my-container" + self.file_content = b"test content" + self.filename = "test.txt" + self.filename_extra = "test_extra.txt" + self.file_bytesio_empty = io.BytesIO() + + # Apply mocks to the Storage instance + self.Storage.blob_service_client = mock_blob_service_client + self.Storage.container_client = mock_container_client + + def test_upload_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + + # Simulate an error when container does not exist + self.Storage.container_client.get_blob_client.side_effect = Exception( + "Container does not exist" + ) + with pytest.raises(Exception): + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + + # Reset side effect and create container + self.Storage.container_client.get_blob_client.side_effect = None + self.Storage.create_container() + contents, azure_file_path = self.Storage.upload_file( + io.BytesIO(self.file_content), self.filename + ) + + # Assertions + self.Storage.container_client.get_blob_client.assert_called_with(self.filename) + self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with( + self.file_content, overwrite=True + ) + assert contents == self.file_content + assert ( + azure_file_path + == f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + ) + assert (upload_dir / self.filename).exists() + assert (upload_dir / 
self.filename).read_bytes() == self.file_content + + with pytest.raises(ValueError): + self.Storage.upload_file(self.file_bytesio_empty, self.filename) + + def test_get_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.Storage.create_container() + + # Mock upload behavior + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + # Mock blob download behavior + self.Storage.container_client.get_blob_client().download_blob().readall.return_value = ( + self.file_content + ) + + file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + file_path = self.Storage.get_file(file_url) + + assert file_path == str(upload_dir / self.filename) + assert (upload_dir / self.filename).exists() + assert (upload_dir / self.filename).read_bytes() == self.file_content + + def test_delete_file(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.Storage.create_container() + + # Mock file upload + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + # Mock deletion + self.Storage.container_client.get_blob_client().delete_blob.return_value = None + + file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + self.Storage.delete_file(file_url) + + self.Storage.container_client.get_blob_client().delete_blob.assert_called_once() + assert not (upload_dir / self.filename).exists() + + def test_delete_all_files(self, monkeypatch, tmp_path): + upload_dir = mock_upload_dir(monkeypatch, tmp_path) + self.Storage.create_container() + + # Mock file uploads + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename) + self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra) + + # Mock listing and deletion behavior + self.Storage.container_client.list_blobs.return_value = [ + {"name": self.filename}, + {"name": self.filename_extra}, + ] + 
self.Storage.container_client.get_blob_client().delete_blob.return_value = None + + self.Storage.delete_all_files() + + self.Storage.container_client.list_blobs.assert_called_once() + self.Storage.container_client.get_blob_client().delete_blob.assert_any_call() + assert not (upload_dir / self.filename).exists() + assert not (upload_dir / self.filename_extra).exists() + + def test_get_file_not_found(self, monkeypatch): + self.Storage.create_container() + + file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}" + # Mock behavior to raise an error for missing blobs + self.Storage.container_client.get_blob_client().download_blob.side_effect = ( + Exception("Blob not found") + ) + with pytest.raises(Exception, match="Blob not found"): + self.Storage.get_file(file_url) diff --git a/backend/open_webui/utils/audit.py b/backend/open_webui/utils/audit.py new file mode 100644 index 0000000000..2d7ceabcb8 --- /dev/null +++ b/backend/open_webui/utils/audit.py @@ -0,0 +1,249 @@ +from contextlib import asynccontextmanager +from dataclasses import asdict, dataclass +from enum import Enum +import re +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + MutableMapping, + Optional, + cast, +) +import uuid + +from asgiref.typing import ( + ASGI3Application, + ASGIReceiveCallable, + ASGIReceiveEvent, + ASGISendCallable, + ASGISendEvent, + Scope as ASGIScope, +) +from loguru import logger +from starlette.requests import Request + +from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE +from open_webui.utils.auth import get_current_user, get_http_authorization_cred +from open_webui.models.users import UserModel + + +if TYPE_CHECKING: + from loguru import Logger + + +@dataclass(frozen=True) +class AuditLogEntry: + # `Metadata` audit level properties + id: str + user: dict[str, Any] + audit_level: str + verb: str + request_uri: str + user_agent: Optional[str] = None + source_ip: Optional[str] = None + # `Request` audit 
level properties + request_object: Any = None + # `Request Response` level + response_object: Any = None + response_status_code: Optional[int] = None + + +class AuditLevel(str, Enum): + NONE = "NONE" + METADATA = "METADATA" + REQUEST = "REQUEST" + REQUEST_RESPONSE = "REQUEST_RESPONSE" + + +class AuditLogger: + """ + A helper class that encapsulates audit logging functionality. It uses Loguru’s logger with an auditable binding to ensure that audit log entries are filtered correctly. + + Parameters: + logger (Logger): An instance of Loguru’s logger. + """ + + def __init__(self, logger: "Logger"): + self.logger = logger.bind(auditable=True) + + def write( + self, + audit_entry: AuditLogEntry, + *, + log_level: str = "INFO", + extra: Optional[dict] = None, + ): + + entry = asdict(audit_entry) + + if extra: + entry["extra"] = extra + + self.logger.log( + log_level, + "", + **entry, + ) + + +class AuditContext: + """ + Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage. + + Attributes: + request_body (bytearray): Accumulated request payload. + response_body (bytearray): Accumulated response payload. + max_body_size (int): Maximum number of bytes to capture. + metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.). 
+ """ + + def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE): + self.request_body = bytearray() + self.response_body = bytearray() + self.max_body_size = max_body_size + self.metadata: Dict[str, Any] = {} + + def add_request_chunk(self, chunk: bytes): + if len(self.request_body) < self.max_body_size: + self.request_body.extend( + chunk[: self.max_body_size - len(self.request_body)] + ) + + def add_response_chunk(self, chunk: bytes): + if len(self.response_body) < self.max_body_size: + self.response_body.extend( + chunk[: self.max_body_size - len(self.response_body)] + ) + + +class AuditLoggingMiddleware: + """ + ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle. + """ + + AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"} + + def __init__( + self, + app: ASGI3Application, + *, + excluded_paths: Optional[list[str]] = None, + max_body_size: int = MAX_BODY_LOG_SIZE, + audit_level: AuditLevel = AuditLevel.NONE, + ) -> None: + self.app = app + self.audit_logger = AuditLogger(logger) + self.excluded_paths = excluded_paths or [] + self.max_body_size = max_body_size + self.audit_level = audit_level + + async def __call__( + self, + scope: ASGIScope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope=cast(MutableMapping, scope)) + + if self._should_skip_auditing(request): + return await self.app(scope, receive, send) + + async with self._audit_context(request) as context: + + async def send_wrapper(message: ASGISendEvent) -> None: + if self.audit_level == AuditLevel.REQUEST_RESPONSE: + await self._capture_response(message, context) + + await send(message) + + original_receive = receive + + async def receive_wrapper() -> 
ASGIReceiveEvent: + nonlocal original_receive + message = await original_receive() + + if self.audit_level in ( + AuditLevel.REQUEST, + AuditLevel.REQUEST_RESPONSE, + ): + await self._capture_request(message, context) + + return message + + await self.app(scope, receive_wrapper, send_wrapper) + + @asynccontextmanager + async def _audit_context( + self, request: Request + ) -> AsyncGenerator[AuditContext, None]: + """ + async context manager that ensures that an audit log entry is recorded after the request is processed. + """ + context = AuditContext() + try: + yield context + finally: + await self._log_audit_entry(request, context) + + async def _get_authenticated_user(self, request: Request) -> UserModel: + + auth_header = request.headers.get("Authorization") + assert auth_header + user = get_current_user(request, None, get_http_authorization_cred(auth_header)) + + return user + + def _should_skip_auditing(self, request: Request) -> bool: + if ( + request.method not in {"POST", "PUT", "PATCH", "DELETE"} + or AUDIT_LOG_LEVEL == "NONE" + or not request.headers.get("authorization") + ): + return True + # match either /api//...(for the endpoint /api/chat case) or /api/v1//... 
+ pattern = re.compile( + r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b" + ) + if pattern.match(request.url.path): + return True + + return False + + async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext): + if message["type"] == "http.request": + body = message.get("body", b"") + context.add_request_chunk(body) + + async def _capture_response(self, message: ASGISendEvent, context: AuditContext): + if message["type"] == "http.response.start": + context.metadata["response_status_code"] = message["status"] + + elif message["type"] == "http.response.body": + body = message.get("body", b"") + context.add_response_chunk(body) + + async def _log_audit_entry(self, request: Request, context: AuditContext): + try: + user = await self._get_authenticated_user(request) + + entry = AuditLogEntry( + id=str(uuid.uuid4()), + user=user.model_dump(include={"id", "name", "email", "role"}), + audit_level=self.audit_level.value, + verb=request.method, + request_uri=str(request.url), + response_status_code=context.metadata.get("response_status_code", None), + source_ip=request.client.host if request.client else None, + user_agent=request.headers.get("user-agent"), + request_object=context.request_body.decode("utf-8", errors="replace"), + response_object=context.response_body.decode("utf-8", errors="replace"), + ) + + self.audit_logger.write(entry) + except Exception as e: + logger.error(f"Failed to log audit entry: {str(e)}") diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index e4ad27c846..cbc8b15aed 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -5,6 +5,7 @@ import base64 import hmac import hashlib import requests +import os from datetime import UTC, datetime, timedelta @@ -13,15 +14,22 @@ from typing import Optional, Union, List, Dict from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES -from open_webui.config import override_static 
-from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY +from open_webui.env import ( + WEBUI_SECRET_KEY, + TRUSTED_SIGNATURE_KEY, + STATIC_DIR, + SRC_LOG_LEVELS, +) -from fastapi import Depends, HTTPException, Request, Response, status +from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from passlib.context import CryptContext + logging.getLogger("passlib").setLevel(logging.ERROR) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["OAUTH"]) SESSION_SECRET = WEBUI_SECRET_KEY ALGORITHM = "HS256" @@ -47,6 +55,19 @@ def verify_signature(payload: str, signature: str) -> bool: return False +def override_static(path: str, content: str): + # Ensure path is safe + if "/" in path or ".." in path: + log.error(f"Invalid path: {path}") + return + + file_path = os.path.join(STATIC_DIR, path) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, "wb") as f: + f.write(base64.b64decode(content)) # Convert Base64 back to raw binary + + def get_license_data(app, key): if key: try: @@ -69,11 +90,11 @@ def get_license_data(app, key): return True else: - print( + log.error( f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}" ) except Exception as ex: - print(f"License: Uncaught Exception: {ex}") + log.exception(f"License: Uncaught Exception: {ex}") return False @@ -129,6 +150,7 @@ def get_http_authorization_cred(auth_header: str): def get_current_user( request: Request, + background_tasks: BackgroundTasks, auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), ): token = None @@ -181,7 +203,10 @@ def get_current_user( detail=ERROR_MESSAGES.INVALID_TOKEN, ) else: - Users.update_user_last_active_by_id(user.id) + # Refresh the user's last active timestamp asynchronously + # to prevent blocking the request + if background_tasks: + background_tasks.add_task(Users.update_user_last_active_by_id, user.id) 
return user else: raise HTTPException( diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 73e4264bf4..74d0af4f7a 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -66,7 +66,7 @@ async def generate_direct_chat_completion( user: Any, models: dict, ): - print("generate_direct_chat_completion") + log.info("generate_direct_chat_completion") metadata = form_data.pop("metadata", {}) @@ -103,7 +103,7 @@ async def generate_direct_chat_completion( } ) - print("res", res) + log.info(f"res: {res}") if res.get("status", False): # Define a generator to stream responses @@ -200,7 +200,7 @@ async def generate_chat_completion( except Exception as e: raise e - if model["owned_by"] == "arena": + if model.get("owned_by") == "arena": model_ids = model.get("info", {}).get("meta", {}).get("model_ids") filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") if model_ids and filter_mode == "exclude": @@ -253,7 +253,7 @@ async def generate_chat_completion( return await generate_function_chat_completion( request, form_data, user=user, models=models ) - if model["owned_by"] == "ollama": + if model.get("owned_by") == "ollama": # Using /ollama/api/chat endpoint form_data = convert_payload_openai_to_ollama(form_data) response = await generate_ollama_chat_completion( @@ -285,7 +285,7 @@ chat_completion = generate_chat_completion async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { @@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A raise Exception(f"Action not found: {action_id}") if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and 
hasattr(request.state, "model"): models = { @@ -432,7 +432,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A ) ) except Exception as e: - print(e) + log.exception(f"Failed to get user values: {e}") params = {**params, "__user__": __user__} diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index de51bd46e5..0ca754ed8d 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -1,6 +1,12 @@ import inspect +import logging + from open_webui.utils.plugin import load_function_module_by_id from open_webui.models.functions import Functions +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) def get_sorted_filter_ids(model): @@ -61,7 +67,12 @@ async def process_filter_functions( try: # Prepare parameters sig = inspect.signature(handler) - params = {"body": form_data} | { + + params = {"body": form_data} + if filter_type == "stream": + params = {"event": form_data} + + params = params | { k: v for k, v in { **extra_params, @@ -80,7 +91,7 @@ async def process_filter_functions( ) ) except Exception as e: - print(e) + log.exception(f"Failed to get user values: {e}") # Execute handler if inspect.iscoroutinefunction(handler): @@ -89,7 +100,7 @@ async def process_filter_functions( form_data = handler(**params) except Exception as e: - print(f"Error in {filter_type} handler {filter_id}: {e}") + log.exception(f"Error in {filter_type} handler {filter_id}: {e}") raise e # Handle file cleanup for inlet diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py new file mode 100644 index 0000000000..2557610060 --- /dev/null +++ b/backend/open_webui/utils/logger.py @@ -0,0 +1,140 @@ +import json +import logging +import sys +from typing import TYPE_CHECKING + +from loguru import logger + +from open_webui.env import ( + AUDIT_LOG_FILE_ROTATION_SIZE, + AUDIT_LOG_LEVEL, + AUDIT_LOGS_FILE_PATH, + 
GLOBAL_LOG_LEVEL, +) + + +if TYPE_CHECKING: + from loguru import Record + + +def stdout_format(record: "Record") -> str: + """ + Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON). + + Parameters: + record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context. + Returns: + str: A formatted log string intended for stdout. + """ + record["extra"]["extra_json"] = json.dumps(record["extra"]) + return ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} - " + "{message} - {extra[extra_json]}" + "\n{exception}" + ) + + +class InterceptHandler(logging.Handler): + """ + Intercepts log records from Python's standard logging module + and redirects them to Loguru's logger. + """ + + def emit(self, record): + """ + Called by the standard logging module for each log event. + It transforms the standard `LogRecord` into a format compatible with Loguru + and passes it to Loguru's logger. + """ + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + frame, depth = sys._getframe(6), 6 + while frame and frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + +def file_format(record: "Record"): + """ + Formats audit log records into a structured JSON string for file output. + + Parameters: + record (Record): A Loguru record containing extra audit data. + Returns: + str: A JSON-formatted string representing the audit data. 
+ """ + + audit_data = { + "id": record["extra"].get("id", ""), + "timestamp": int(record["time"].timestamp()), + "user": record["extra"].get("user", dict()), + "audit_level": record["extra"].get("audit_level", ""), + "verb": record["extra"].get("verb", ""), + "request_uri": record["extra"].get("request_uri", ""), + "response_status_code": record["extra"].get("response_status_code", 0), + "source_ip": record["extra"].get("source_ip", ""), + "user_agent": record["extra"].get("user_agent", ""), + "request_object": record["extra"].get("request_object", b""), + "response_object": record["extra"].get("response_object", b""), + "extra": record["extra"].get("extra", {}), + } + + record["extra"]["file_extra"] = json.dumps(audit_data, default=str) + return "{extra[file_extra]}\n" + + +def start_logger(): + """ + Initializes and configures Loguru's logger with distinct handlers: + + A console (stdout) handler for general log messages (excluding those marked as auditable). + An optional file handler for audit logs if audit logging is enabled. + Additionally, this function reconfigures Python’s standard logging to route through Loguru and adjusts logging levels for Uvicorn. + + Parameters: + enable_audit_logging (bool): Determines whether audit-specific log entries should be recorded to file. 
+ """ + logger.remove() + + logger.add( + sys.stdout, + level=GLOBAL_LOG_LEVEL, + format=stdout_format, + filter=lambda record: "auditable" not in record["extra"], + ) + + if AUDIT_LOG_LEVEL != "NONE": + try: + logger.add( + AUDIT_LOGS_FILE_PATH, + level="INFO", + rotation=AUDIT_LOG_FILE_ROTATION_SIZE, + compression="zip", + format=file_format, + filter=lambda record: record["extra"].get("auditable") is True, + ) + except Exception as e: + logger.error(f"Failed to initialize audit log file handler: {str(e)}") + + logging.basicConfig( + handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True + ) + for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]: + uvicorn_logger = logging.getLogger(uvicorn_logger_name) + uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) + uvicorn_logger.handlers = [] + for uvicorn_logger_name in ["uvicorn.access"]: + uvicorn_logger = logging.getLogger(uvicorn_logger_name) + uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) + uvicorn_logger.handlers = [InterceptHandler()] + + logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index d52be24879..de2b9c468f 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -322,78 +322,95 @@ async def chat_web_search_handler( ) return form_data - searchQuery = queries[0] + all_results = [] - await event_emitter( - { - "type": "status", - "data": { - "action": "web_search", - "description": 'Searching "{{searchQuery}}"', - "query": searchQuery, - "done": False, - }, - } - ) - - try: - - results = await process_web_search( - request, - SearchForm( - **{ + for searchQuery in queries: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": 'Searching "{{searchQuery}}"', "query": searchQuery, - } - ), - user, + "done": False, + }, + } ) - if results: - await event_emitter( - { - "type": "status", - "data": { - "action": "web_search", - 
"description": "Searched {{count}} sites", + try: + results = await process_web_search( + request, + SearchForm( + **{ "query": searchQuery, - "urls": results["filenames"], - "done": True, - }, - } + } + ), + user=user, ) - files = form_data.get("files", []) - files.append( - { - "collection_name": results["collection_name"], - "name": searchQuery, - "type": "web_search_results", - "urls": results["filenames"], - } - ) - form_data["files"] = files - else: + if results: + all_results.append(results) + files = form_data.get("files", []) + + if results.get("collection_name"): + files.append( + { + "collection_name": results["collection_name"], + "name": searchQuery, + "type": "web_search", + "urls": results["filenames"], + } + ) + elif results.get("docs"): + files.append( + { + "docs": results.get("docs", []), + "name": searchQuery, + "type": "web_search", + "urls": results["filenames"], + } + ) + + form_data["files"] = files + except Exception as e: + log.exception(e) await event_emitter( { "type": "status", "data": { "action": "web_search", - "description": "No search results found", + "description": 'Error searching "{{searchQuery}}"', "query": searchQuery, "done": True, "error": True, }, } ) - except Exception as e: - log.exception(e) + + if all_results: + urls = [] + for results in all_results: + if "filenames" in results: + urls.extend(results["filenames"]) + await event_emitter( { "type": "status", "data": { "action": "web_search", - "description": 'Error searching "{{searchQuery}}"', - "query": searchQuery, + "description": "Searched {{count}} sites", + "urls": urls, + "done": True, + }, + } + ) + else: + await event_emitter( + { + "type": "status", + "data": { + "action": "web_search", + "description": "No search results found", "done": True, "error": True, }, @@ -503,6 +520,7 @@ async def chat_completion_files_handler( sources = [] if files := body.get("metadata", {}).get("files", None): + queries = [] try: queries_response = await generate_queries( request, 
@@ -528,8 +546,8 @@ async def chat_completion_files_handler( queries_response = {"queries": [queries_response]} queries = queries_response.get("queries", []) - except Exception as e: - queries = [] + except: + pass if len(queries) == 0: queries = [get_last_user_message(body["messages"])] @@ -541,6 +559,7 @@ async def chat_completion_files_handler( sources = await loop.run_in_executor( executor, lambda: get_sources_from_files( + request=request, files=files, queries=queries, embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( @@ -550,9 +569,9 @@ async def chat_completion_files_handler( reranking_function=request.app.state.rf, r=request.app.state.config.RELEVANCE_THRESHOLD, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, + full_context=request.app.state.config.RAG_FULL_CONTEXT, ), ) - except Exception as e: log.exception(e) @@ -728,6 +747,7 @@ async def process_chat_payload(request, form_data, metadata, user, model): tool_ids = form_data.pop("tool_ids", None) files = form_data.pop("files", None) + # Remove files duplicates if files: files = list({json.dumps(f, sort_keys=True): f for f in files}.values()) @@ -785,8 +805,6 @@ async def process_chat_payload(request, form_data, metadata, user, model): if len(sources) > 0: context_string = "" for source_idx, source in enumerate(sources): - source_id = source.get("source", {}).get("name", "") - if "document" in source: for doc_idx, doc_context in enumerate(source["document"]): context_string += f"{source_idx}{doc_context}\n" @@ -806,7 +824,7 @@ async def process_chat_payload(request, form_data, metadata, user, model): # Workaround for Ollama 2.0+ system prompt issue # TODO: replace with add_or_update_system_message - if model["owned_by"] == "ollama": + if model.get("owned_by") == "ollama": form_data["messages"] = prepend_to_first_user_message_content( rag_template( request.app.state.config.RAG_TEMPLATE, context_string, prompt @@ -1038,6 +1056,21 @@ async def process_chat_response( ): 
return response + extra_params = { + "__event_emitter__": event_emitter, + "__event_call__": event_caller, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + "__metadata__": metadata, + "__request__": request, + "__model__": metadata.get("model"), + } + filter_ids = get_sorted_filter_ids(form_data.get("model")) + # Streaming response if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. @@ -1117,12 +1150,12 @@ async def process_chat_response( if reasoning_duration is not None: if raw: - content = f'{content}\n<{block["tag"]}>{block["content"]}\n' + content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' else: content = f'{content}\n
<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>
\n' else: if raw: - content = f'{content}\n<{block["tag"]}>{block["content"]}\n' + content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' else: content = f'{content}\n
<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>
\n' @@ -1218,9 +1251,9 @@ async def process_chat_response( return attributes if content_blocks[-1]["type"] == "text": - for tag in tags: + for start_tag, end_tag in tags: # Match start tag e.g., or - start_tag_pattern = rf"<{tag}(\s.*?)?>" + start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>" match = re.search(start_tag_pattern, content) if match: attr_content = ( @@ -1253,7 +1286,8 @@ async def process_chat_response( content_blocks.append( { "type": content_type, - "tag": tag, + "start_tag": start_tag, + "end_tag": end_tag, "attributes": attributes, "content": "", "started_at": time.time(), @@ -1265,9 +1299,10 @@ async def process_chat_response( break elif content_blocks[-1]["type"] == content_type: - tag = content_blocks[-1]["tag"] + start_tag = content_blocks[-1]["start_tag"] + end_tag = content_blocks[-1]["end_tag"] # Match end tag e.g., - end_tag_pattern = rf"" + end_tag_pattern = rf"<{re.escape(end_tag)}>" # Check if the content has the end tag if re.search(end_tag_pattern, content): @@ -1275,7 +1310,7 @@ async def process_chat_response( block_content = content_blocks[-1]["content"] # Strip start and end tags from the content - start_tag_pattern = rf"<{tag}(.*?)>" + start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>" block_content = re.sub( start_tag_pattern, "", block_content ).strip() @@ -1340,7 +1375,7 @@ async def process_chat_response( # Clean processed content content = re.sub( - rf"<{tag}(.*?)>(.|\n)*?", + rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>", "", content, flags=re.DOTALL, @@ -1353,7 +1388,22 @@ async def process_chat_response( ) tool_calls = [] - content = message.get("content", "") if message else "" + + last_assistant_message = None + try: + if form_data["messages"][-1]["role"] == "assistant": + last_assistant_message = get_last_assistant_message( + form_data["messages"] + ) + except Exception as e: + pass + + content = ( + message.get("content", "") + if message + else last_assistant_message if 
last_assistant_message else "" + ) + content_blocks = [ { "type": "text", @@ -1363,19 +1413,24 @@ async def process_chat_response( # We might want to disable this by default DETECT_REASONING = True + DETECT_SOLUTION = True DETECT_CODE_INTERPRETER = metadata.get("features", {}).get( "code_interpreter", False ) reasoning_tags = [ - "think", - "thinking", - "reason", - "reasoning", - "thought", - "Thought", + ("think", "/think"), + ("thinking", "/thinking"), + ("reason", "/reason"), + ("reasoning", "/reasoning"), + ("thought", "/thought"), + ("Thought", "/Thought"), + ("|begin_of_thought|", "|end_of_thought|"), ] - code_interpreter_tags = ["code_interpreter"] + + code_interpreter_tags = [("code_interpreter", "/code_interpreter")] + + solution_tags = [("|begin_of_solution|", "|end_of_solution|")] try: for event in events: @@ -1419,119 +1474,154 @@ async def process_chat_response( try: data = json.loads(data) - if "selected_model_id" in data: - model_id = data["selected_model_id"] - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "selectedModelId": model_id, - }, - ) - else: - choices = data.get("choices", []) - if not choices: - continue + data, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=data, + extra_params=extra_params, + ) - delta = choices[0].get("delta", {}) - delta_tool_calls = delta.get("tool_calls", None) - - if delta_tool_calls: - for delta_tool_call in delta_tool_calls: - tool_call_index = delta_tool_call.get("index") - - if tool_call_index is not None: - if ( - len(response_tool_calls) - <= tool_call_index - ): - response_tool_calls.append( - delta_tool_call - ) - else: - delta_name = delta_tool_call.get( - "function", {} - ).get("name") - delta_arguments = delta_tool_call.get( - "function", {} - ).get("arguments") - - if delta_name: - response_tool_calls[ - tool_call_index - ]["function"]["name"] += delta_name - - if 
delta_arguments: - response_tool_calls[ - tool_call_index - ]["function"][ - "arguments" - ] += delta_arguments - - value = delta.get("content") - - if value: - content = f"{content}{value}" - - if not content_blocks: - content_blocks.append( - { - "type": "text", - "content": "", - } - ) - - content_blocks[-1]["content"] = ( - content_blocks[-1]["content"] + value + if data: + if "selected_model_id" in data: + model_id = data["selected_model_id"] + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "selectedModelId": model_id, + }, ) - - if DETECT_REASONING: - content, content_blocks, _ = ( - tag_content_handler( - "reasoning", - reasoning_tags, - content, - content_blocks, + else: + choices = data.get("choices", []) + if not choices: + usage = data.get("usage", {}) + if usage: + await event_emitter( + { + "type": "chat:completion", + "data": { + "usage": usage, + }, + } ) + continue + + delta = choices[0].get("delta", {}) + delta_tool_calls = delta.get("tool_calls", None) + + if delta_tool_calls: + for delta_tool_call in delta_tool_calls: + tool_call_index = delta_tool_call.get( + "index" + ) + + if tool_call_index is not None: + if ( + len(response_tool_calls) + <= tool_call_index + ): + response_tool_calls.append( + delta_tool_call + ) + else: + delta_name = delta_tool_call.get( + "function", {} + ).get("name") + delta_arguments = ( + delta_tool_call.get( + "function", {} + ).get("arguments") + ) + + if delta_name: + response_tool_calls[ + tool_call_index + ]["function"][ + "name" + ] += delta_name + + if delta_arguments: + response_tool_calls[ + tool_call_index + ]["function"][ + "arguments" + ] += delta_arguments + + value = delta.get("content") + + if value: + content = f"{content}{value}" + + if not content_blocks: + content_blocks.append( + { + "type": "text", + "content": "", + } + ) + + content_blocks[-1]["content"] = ( + content_blocks[-1]["content"] + value ) - if DETECT_CODE_INTERPRETER: - 
content, content_blocks, end = ( - tag_content_handler( - "code_interpreter", - code_interpreter_tags, - content, - content_blocks, + if DETECT_REASONING: + content, content_blocks, _ = ( + tag_content_handler( + "reasoning", + reasoning_tags, + content, + content_blocks, + ) ) - ) - if end: - break + if DETECT_CODE_INTERPRETER: + content, content_blocks, end = ( + tag_content_handler( + "code_interpreter", + code_interpreter_tags, + content, + content_blocks, + ) + ) - if ENABLE_REALTIME_CHAT_SAVE: - # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { + if end: + break + + if DETECT_SOLUTION: + content, content_blocks, _ = ( + tag_content_handler( + "solution", + solution_tags, + content, + content_blocks, + ) + ) + + if ENABLE_REALTIME_CHAT_SAVE: + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "content": serialize_content_blocks( + content_blocks + ), + }, + ) + else: + data = { "content": serialize_content_blocks( content_blocks ), - }, - ) - else: - data = { - "content": serialize_content_blocks( - content_blocks - ), - } + } - await event_emitter( - { - "type": "chat:completion", - "data": data, - } - ) + await event_emitter( + { + "type": "chat:completion", + "data": data, + } + ) except Exception as e: done = "data: [DONE]" in line if done: @@ -1736,6 +1826,7 @@ async def process_chat_response( == "password" else None ), + request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, ) else: output = { @@ -1829,7 +1920,10 @@ async def process_chat_response( } ) - print(content_blocks, serialize_content_blocks(content_blocks)) + log.info(f"content_blocks={content_blocks}") + log.info( + f"serialize_content_blocks={serialize_content_blocks(content_blocks)}" + ) try: res = await generate_chat_completion( @@ -1900,7 +1994,7 @@ async def process_chat_response( await 
background_tasks_handler() except asyncio.CancelledError: - print("Task was cancelled!") + log.warning("Task was cancelled!") await event_emitter({"type": "task-cancelled"}) if not ENABLE_REALTIME_CHAT_SAVE: @@ -1921,17 +2015,34 @@ async def process_chat_response( return {"status": True, "task_id": task_id} else: - # Fallback to the original response async def stream_wrapper(original_generator, events): def wrap_item(item): return f"data: {item}\n\n" for event in events: - yield wrap_item(json.dumps(event)) + event, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=event, + extra_params=extra_params, + ) + + if event: + yield wrap_item(json.dumps(event)) async for data in original_generator: - yield data + data, _ = await process_filter_functions( + request=request, + filter_ids=filter_ids, + filter_type="stream", + form_data=data, + extra_params=extra_params, + ) + + if data: + yield data return StreamingResponse( stream_wrapper(response.body_iterator, events), diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 8ab7433169..85e7f6415e 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -2,6 +2,7 @@ import hashlib import re import time import uuid +import logging from datetime import timedelta from pathlib import Path from typing import Callable, Optional @@ -9,6 +10,10 @@ import json import collections.abc +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) def deep_update(d, u): @@ -413,7 +418,7 @@ def parse_ollama_modelfile(model_text): elif param_type is bool: value = value.lower() == "true" except Exception as e: - print(e) + log.exception(f"Failed to parse parameter {param}: {e}") continue data["params"][param] = value diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 975f8cb095..149e41a41b 100644 --- 
a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -22,6 +22,7 @@ from open_webui.config import ( ) from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL +from open_webui.models.users import UserModel logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -29,17 +30,17 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def get_all_base_models(request: Request): +async def get_all_base_models(request: Request, user: UserModel = None): function_models = [] openai_models = [] ollama_models = [] if request.app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models(request) + openai_models = await openai.get_all_models(request, user=user) openai_models = openai_models["data"] if request.app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models(request) + ollama_models = await ollama.get_all_models(request, user=user) ollama_models = [ { "id": model["model"], @@ -58,8 +59,8 @@ async def get_all_base_models(request: Request): return models -async def get_all_models(request): - models = await get_all_base_models(request) +async def get_all_models(request, user: UserModel = None): + models = await get_all_base_models(request, user=user) # If there are no models, return an empty list if len(models) == 0: @@ -142,7 +143,7 @@ async def get_all_models(request): custom_model.base_model_id == model["id"] or custom_model.base_model_id == model["id"].split(":")[0] ): - owned_by = model["owned_by"] + owned_by = model.get("owned_by", "unknown owner") if "pipe" in model: pipe = model["pipe"] break diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index a635853d69..2af54c19d7 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -140,7 +140,14 @@ class OAuthManager: log.debug("Running OAUTH Group management") oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM - user_oauth_groups: list[str] = 
user_data.get(oauth_claim, list()) + # Nested claim search for groups claim + if oauth_claim: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + user_oauth_groups = claim_data if isinstance(claim_data, list) else [] + user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) all_available_groups: list[GroupModel] = Groups.get_groups() @@ -239,11 +246,46 @@ class OAuthManager: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) provider_sub = f"{provider}@{sub}" email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() + email = user_data.get(email_claim, "") # We currently mandate that email addresses are provided if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email + if provider == "github": + try: + access_token = token.get("access_token") + headers = {"Authorization": f"Bearer {access_token}"} + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.github.com/user/emails", headers=headers + ) as resp: + if resp.ok: + emails = await resp.json() + # use the primary email as the user's email + primary_email = next( + (e["email"] for e in emails if e.get("primary")), + None, + ) + if primary_email: + email = primary_email + else: + log.warning( + "No primary email found in GitHub response" + ) + raise HTTPException( + 400, detail=ERROR_MESSAGES.INVALID_CRED + ) + else: + log.warning("Failed to fetch GitHub email") + raise HTTPException( + 400, detail=ERROR_MESSAGES.INVALID_CRED + ) + except Exception as e: + log.warning(f"Error fetching GitHub email: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + else: + log.warning(f"OAuth 
callback failed, email is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + email = email.lower() if ( "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS @@ -273,21 +315,10 @@ class OAuthManager: if not user: user_count = Users.get_num_users() - if ( - request.app.state.USER_COUNT - and user_count >= request.app.state.USER_COUNT - ): - raise HTTPException( - 403, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email( - user_data.get("email", "").lower() - ) + existing_user = Users.get_user_by_email(email) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index d078362ee5..46656cc82b 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -4,6 +4,7 @@ from open_webui.utils.misc import ( ) from typing import Callable, Optional +import json # inplace function: form_data is modified @@ -67,38 +68,49 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: - opts = [ - "temperature", - "top_p", - "seed", - "mirostat", - "mirostat_eta", - "mirostat_tau", - "num_ctx", - "num_batch", - "num_keep", - "repeat_last_n", - "tfs_z", - "top_k", - "min_p", - "use_mmap", - "use_mlock", - "num_thread", - "num_gpu", - ] - mappings = {i: lambda x: x for i in opts} - form_data = apply_model_params_to_body(params, form_data, mappings) - + # Convert OpenAI parameter names to Ollama parameter names if needed. 
name_differences = { "max_tokens": "num_predict", - "frequency_penalty": "repeat_penalty", } for key, value in name_differences.items(): if (param := params.get(key, None)) is not None: - form_data[value] = param + # Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided + params[value] = params[key] + del params[key] - return form_data + # See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8 + mappings = { + "temperature": float, + "top_p": float, + "seed": lambda x: x, + "mirostat": int, + "mirostat_eta": float, + "mirostat_tau": float, + "num_ctx": int, + "num_batch": int, + "num_keep": int, + "num_predict": int, + "repeat_last_n": int, + "top_k": int, + "min_p": float, + "typical_p": float, + "repeat_penalty": float, + "presence_penalty": float, + "frequency_penalty": float, + "penalize_newline": bool, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + "numa": bool, + "num_gpu": int, + "main_gpu": int, + "low_vram": bool, + "vocab_only": bool, + "use_mmap": bool, + "use_mlock": bool, + "num_thread": int, + } + + return apply_model_params_to_body(params, form_data, mappings) def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]: @@ -109,11 +121,38 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]: new_message = {"role": message["role"]} content = message.get("content", []) + tool_calls = message.get("tool_calls", None) + tool_call_id = message.get("tool_call_id", None) # Check if the content is a string (just a simple message) - if isinstance(content, str): + if isinstance(content, str) and not tool_calls: # If the content is a string, it's pure text new_message["content"] = content + + # If message is a tool call, add the tool call id to the message + if tool_call_id: + new_message["tool_call_id"] = tool_call_id + + elif tool_calls: + # If tool calls are present, add them to the message + ollama_tool_calls = [] + for tool_call 
in tool_calls: + ollama_tool_call = { + "index": tool_call.get("index", 0), + "id": tool_call.get("id", None), + "function": { + "name": tool_call.get("function", {}).get("name", ""), + "arguments": json.loads( + tool_call.get("function", {}).get("arguments", {}) + ), + }, + } + ollama_tool_calls.append(ollama_tool_call) + new_message["tool_calls"] = ollama_tool_calls + + # Put the content to empty string (Ollama requires an empty string for tool calls) + new_message["content"] = "" + else: # Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL content_text = "" @@ -174,33 +213,28 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: ollama_payload["format"] = openai_payload["format"] # If there are advanced parameters in the payload, format them in Ollama's options field - ollama_options = {} - if openai_payload.get("options"): ollama_payload["options"] = openai_payload["options"] ollama_options = openai_payload["options"] - # Handle parameters which map directly - for param in ["temperature", "top_p", "seed"]: - if param in openai_payload: - ollama_options[param] = openai_payload[param] + # Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict` + if "max_tokens" in ollama_options: + ollama_options["num_predict"] = ollama_options["max_tokens"] + del ollama_options[ + "max_tokens" + ] # To prevent Ollama warning of invalid option provided - # Mapping OpenAI's `max_tokens` -> Ollama's `num_predict` - if "max_completion_tokens" in openai_payload: - ollama_options["num_predict"] = openai_payload["max_completion_tokens"] - elif "max_tokens" in openai_payload: - ollama_options["num_predict"] = openai_payload["max_tokens"] + # Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down. 
+ if "system" in ollama_options: + ollama_payload["system"] = ollama_options["system"] + del ollama_options[ + "system" + ] # To prevent Ollama warning of invalid option provided - # Handle frequency / presence_penalty, which needs renaming and checking - if "frequency_penalty" in openai_payload: - ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"] - - if "presence_penalty" in openai_payload and "penalty" not in ollama_options: - # We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists. - ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"] - - # Add options to payload if any have been set - if ollama_options: + # If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options + if "stop" in openai_payload: + ollama_options = ollama_payload.get("options", {}) + ollama_options["stop"] = openai_payload.get("stop") ollama_payload["options"] = ollama_options if "metadata" in openai_payload: diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index d6e24d6b93..e3fe9237f5 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -45,7 +45,7 @@ def extract_frontmatter(content): frontmatter[key.strip()] = value.strip() except Exception as e: - print(f"An error occurred: {e}") + log.exception(f"Failed to extract frontmatter: {e}") return {} return frontmatter diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index f9979b4a27..8c3f1a58eb 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -24,17 +24,8 @@ def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict: return openai_tool_calls -def convert_response_ollama_to_openai(ollama_response: dict) -> dict: - model = ollama_response.get("model", "ollama") - message_content = ollama_response.get("message", {}).get("content", "") - tool_calls = 
ollama_response.get("message", {}).get("tool_calls", None) - openai_tool_calls = None - - if tool_calls: - openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) - - data = ollama_response - usage = { +def convert_ollama_usage_to_openai(data: dict) -> dict: + return { "response_token/s": ( round( ( @@ -66,14 +57,42 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict: "total_duration": data.get("total_duration", 0), "load_duration": data.get("load_duration", 0), "prompt_eval_count": data.get("prompt_eval_count", 0), + "prompt_tokens": int( + data.get("prompt_eval_count", 0) + ), # This is the OpenAI compatible key "prompt_eval_duration": data.get("prompt_eval_duration", 0), "eval_count": data.get("eval_count", 0), + "completion_tokens": int( + data.get("eval_count", 0) + ), # This is the OpenAI compatible key "eval_duration": data.get("eval_duration", 0), "approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")( (data.get("total_duration", 0) or 0) // 1_000_000_000 ), + "total_tokens": int( # This is the OpenAI compatible key + data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + ), + "completion_tokens_details": { # This is the OpenAI compatible key + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, } + +def convert_response_ollama_to_openai(ollama_response: dict) -> dict: + model = ollama_response.get("model", "ollama") + message_content = ollama_response.get("message", {}).get("content", "") + tool_calls = ollama_response.get("message", {}).get("tool_calls", None) + openai_tool_calls = None + + if tool_calls: + openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls) + + data = ollama_response + + usage = convert_ollama_usage_to_openai(data) + response = openai_chat_completion_message_template( model, message_content, openai_tool_calls, usage ) @@ -85,7 +104,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) 
data = json.loads(data) model = data.get("model", "ollama") - message_content = data.get("message", {}).get("content", "") + message_content = data.get("message", {}).get("content", None) tool_calls = data.get("message", {}).get("tool_calls", None) openai_tool_calls = None @@ -96,48 +115,10 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) usage = None if done: - usage = { - "response_token/s": ( - round( - ( - ( - data.get("eval_count", 0) - / ((data.get("eval_duration", 0) / 10_000_000)) - ) - * 100 - ), - 2, - ) - if data.get("eval_duration", 0) > 0 - else "N/A" - ), - "prompt_token/s": ( - round( - ( - ( - data.get("prompt_eval_count", 0) - / ((data.get("prompt_eval_duration", 0) / 10_000_000)) - ) - * 100 - ), - 2, - ) - if data.get("prompt_eval_duration", 0) > 0 - else "N/A" - ), - "total_duration": data.get("total_duration", 0), - "load_duration": data.get("load_duration", 0), - "prompt_eval_count": data.get("prompt_eval_count", 0), - "prompt_eval_duration": data.get("prompt_eval_duration", 0), - "eval_count": data.get("eval_count", 0), - "eval_duration": data.get("eval_duration", 0), - "approximate_total": ( - lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s" - )((data.get("total_duration", 0) or 0) // 1_000_000_000), - } + usage = convert_ollama_usage_to_openai(data) data = openai_chat_chunk_message_template( - model, message_content if not done else None, openai_tool_calls, usage + model, message_content, openai_tool_calls, usage ) line = f"data: {json.dumps(data)}\n\n" diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 3d8c05d455..5663ce2acb 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -22,7 +22,7 @@ def get_task_model_id( # Set the task model task_model_id = default_model_id # Check if the user has a custom task model and use that model - if models[task_model_id]["owned_by"] == "ollama": + if models[task_model_id].get("owned_by") == 
"ollama": if task_model and task_model in models: task_model_id = task_model else: diff --git a/backend/requirements.txt b/backend/requirements.txt index 9b859b84a0..616827144c 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,10 +1,10 @@ fastapi==0.115.7 uvicorn[standard]==0.30.6 -pydantic==2.9.2 +pydantic==2.10.6 python-multipart==0.0.18 python-socketio==5.11.3 -python-jose==3.3.0 +python-jose==3.4.0 passlib[bcrypt]==1.7.4 requests==2.32.3 @@ -31,6 +31,9 @@ APScheduler==3.10.4 RestrictedPython==8.0 +loguru==0.7.2 +asgiref==3.8.1 + # AI libraries openai anthropic @@ -45,7 +48,7 @@ chromadb==0.6.2 pymilvus==2.5.0 qdrant-client~=1.12.0 opensearch-py==2.8.0 - +playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml transformers sentence-transformers==3.3.1 @@ -59,7 +62,7 @@ fpdf2==2.8.2 pymdown-extensions==10.14.2 docx2txt==0.8 python-pptx==1.0.0 -unstructured==0.16.11 +unstructured==0.16.17 nltk==3.9.1 Markdown==3.7 pypandoc==1.13 @@ -71,6 +74,7 @@ validators==0.34.0 psutil sentencepiece soundfile==0.13.1 +azure-ai-documentintelligence==1.0.0 opencv-python-headless==4.11.0.86 rapidocr-onnxruntime==1.3.24 @@ -103,5 +107,12 @@ pytest-docker~=3.1.1 googleapis-common-protos==1.63.2 google-cloud-storage==2.19.0 +azure-identity==1.20.0 +azure-storage-blob==12.24.1 + + ## LDAP ldap3==2.9.1 + +## Firecrawl +firecrawl-py==1.12.0 diff --git a/backend/start.sh b/backend/start.sh index a945acb62e..671c22ff73 100755 --- a/backend/start.sh +++ b/backend/start.sh @@ -3,6 +3,17 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) cd "$SCRIPT_DIR" || exit +# Add conditional Playwright browser installation +if [[ "${RAG_WEB_LOADER_ENGINE,,}" == "playwright" ]]; then + if [[ -z "${PLAYWRIGHT_WS_URI}" ]]; then + echo "Installing Playwright browsers..." 
+ playwright install chromium + playwright install-deps chromium + fi + + python -c "import nltk; nltk.download('punkt_tab')" +fi + KEY_FILE=.webui_secret_key PORT="${PORT:-8080}" diff --git a/backend/start_windows.bat b/backend/start_windows.bat index 3e8c6b97c4..7049cd1b37 100644 --- a/backend/start_windows.bat +++ b/backend/start_windows.bat @@ -6,6 +6,17 @@ SETLOCAL ENABLEDELAYEDEXPANSION SET "SCRIPT_DIR=%~dp0" cd /d "%SCRIPT_DIR%" || exit /b +:: Add conditional Playwright browser installation +IF /I "%RAG_WEB_LOADER_ENGINE%" == "playwright" ( + IF "%PLAYWRIGHT_WS_URI%" == "" ( + echo Installing Playwright browsers... + playwright install chromium + playwright install-deps chromium + ) + + python -c "import nltk; nltk.download('punkt_tab')" +) + SET "KEY_FILE=.webui_secret_key" IF "%PORT%"=="" SET PORT=8080 IF "%HOST%"=="" SET HOST=0.0.0.0 diff --git a/docker-compose.playwright.yaml b/docker-compose.playwright.yaml new file mode 100644 index 0000000000..fe570bed01 --- /dev/null +++ b/docker-compose.playwright.yaml @@ -0,0 +1,10 @@ +services: + playwright: + image: mcr.microsoft.com/playwright:v1.49.1-noble # Version must match requirements.txt + container_name: playwright + command: npx -y playwright@1.49.1 run-server --port 3000 --host 0.0.0.0 + + open-webui: + environment: + - 'RAG_WEB_LOADER_ENGINE=playwright' + - 'PLAYWRIGHT_WS_URI=ws://playwright:3000' \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index 4add48f972..e48c7930e1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.5.12", + "version": "0.5.18", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.5.12", + "version": "0.5.18", "dependencies": { "@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-python": "^6.1.6", @@ -41,6 +41,7 @@ "i18next-resources-to-backend": "^1.2.0", "idb": "^7.1.1", "js-sha256": "^0.10.1", + "jspdf": "^3.0.0", "katex": 
"^0.16.21", "kokoro-js": "^1.1.1", "marked": "^9.1.0", @@ -63,6 +64,7 @@ "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "turndown": "^7.2.0", + "undici": "^7.3.0", "uuid": "^9.0.1", "vite-plugin-static-copy": "^2.2.0" }, @@ -134,9 +136,10 @@ } }, "node_modules/@babel/runtime": { - "version": "7.24.1", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.24.1.tgz", - "integrity": "sha512-+BIznRzyqBf+2wCTxcKE3wDjfGeCoVE61KSHGpkzqrLi8qxqFwBeUFyId2cxkTmm55fzDGnm0+yCxaxygrLUnQ==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.26.9.tgz", + "integrity": "sha512-aA63XwOkcl4xxQa3HjPMqOP6LiK0ZDv3mUPYEFXkpHbaFjtGggE1A61FjFzJnB+p7/oy2gA8E+rcBNl/zC1tMg==", + "license": "MIT", "dependencies": { "regenerator-runtime": "^0.14.0" }, @@ -3169,6 +3172,13 @@ "integrity": "sha512-Sk/uYFOBAB7mb74XcpizmH0KOR2Pv3D2Hmrh1Dmy5BmK3MpdSa5kqZcg6EKBdklU0bFXX9gCfzvpnyUehrPIuA==", "dev": true }, + "node_modules/@types/raf": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/@types/raf/-/raf-3.4.3.tgz", + "integrity": "sha512-c4YAvMedbPZ5tEyxzQdMoOhhJ4RD3rngZIdwC2/qDN3d7JpEhB6fiBRKVY1lg5B7Wk+uPBjn5f39j1/2MY1oOw==", + "license": "MIT", + "optional": true + }, "node_modules/@types/resolve": { "version": "1.20.2", "resolved": "https://registry.npmjs.org/@types/resolve/-/resolve-1.20.2.tgz", @@ -3198,6 +3208,13 @@ "integrity": "sha512-MQ1AnmTLOncwEf9IVU+B2e4Hchrku5N67NkgcAHW0p3sdzPe0FNMANxEm6OJUzPniEQGkeT3OROLlCwZJLWFZA==", "dev": true }, + "node_modules/@types/trusted-types": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/@types/trusted-types/-/trusted-types-2.0.7.tgz", + "integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==", + "license": "MIT", + "optional": true + }, "node_modules/@types/unist": { "version": "2.0.10", "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz", @@ -3793,6 +3810,18 @@ "node": ">= 4.0.0" } }, + 
"node_modules/atob": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/atob/-/atob-2.1.2.tgz", + "integrity": "sha512-Wm6ukoaOGJi/73p/cl2GvLjTI5JM1k/O14isD73YML8StrH/7/lRFgmg8nICZgD3bZZvjwCGxtMOD3wWNAu8cg==", + "license": "(MIT OR Apache-2.0)", + "bin": { + "atob": "bin/atob.js" + }, + "engines": { + "node": ">= 4.5.0" + } + }, "node_modules/aws-sign2": { "version": "0.7.0", "resolved": "https://registry.npmjs.org/aws-sign2/-/aws-sign2-0.7.0.tgz", @@ -3828,6 +3857,16 @@ "dev": true, "optional": true }, + "node_modules/base64-arraybuffer": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/base64-arraybuffer/-/base64-arraybuffer-1.0.2.tgz", + "integrity": "sha512-I3yl4r9QB5ZRY3XuJVEPfc2XhZO6YweFPI+UovAzn+8/hb3oJ6lnysaFcjVpkCPfVWFUDvoZ8kmVDP7WyRtYtQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6.0" + } + }, "node_modules/base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", @@ -4032,6 +4071,18 @@ "node": "10.* || >= 12.*" } }, + "node_modules/btoa": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/btoa/-/btoa-1.2.1.tgz", + "integrity": "sha512-SB4/MIGlsiVkMcHmT+pSmIPoNDoHg+7cMzmt3Uxt628MTz2487DKSqK/fuhFBrkuqrYv5UCEnACpF4dTFNKc/g==", + "license": "(MIT OR Apache-2.0)", + "bin": { + "btoa": "bin/btoa.js" + }, + "engines": { + "node": ">= 0.4.0" + } + }, "node_modules/buffer": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", @@ -4130,6 +4181,33 @@ "node": ">=6" } }, + "node_modules/canvg": { + "version": "3.0.10", + "resolved": "https://registry.npmjs.org/canvg/-/canvg-3.0.10.tgz", + "integrity": "sha512-qwR2FRNO9NlzTeKIPIKpnTY6fqwuYSequ8Ru8c0YkYU7U0oW+hLUvWadLvAu1Rl72OMNiFhoLu4f8eUjQ7l/+Q==", + "license": "MIT", + "optional": true, + "dependencies": { + "@babel/runtime": "^7.12.5", + "@types/raf": "^3.4.0", + "core-js": "^3.8.3", + "raf": "^3.4.1", + "regenerator-runtime": "^0.13.7", + "rgbcolor": 
"^1.0.1", + "stackblur-canvas": "^2.0.0", + "svg-pathdata": "^6.0.3" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/canvg/node_modules/regenerator-runtime": { + "version": "0.13.11", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.13.11.tgz", + "integrity": "sha512-kY1AZVr2Ra+t+piVaJ4gxaFaReZVH40AKNo7UCX6W+dEwBo/2oZJzqfuN1qLq1oL45o56cPaTXELwrTh8Fpggg==", + "license": "MIT", + "optional": true + }, "node_modules/caseless": { "version": "0.12.0", "resolved": "https://registry.npmjs.org/caseless/-/caseless-0.12.0.tgz", @@ -4598,6 +4676,18 @@ "node": ">= 0.6" } }, + "node_modules/core-js": { + "version": "3.40.0", + "resolved": "https://registry.npmjs.org/core-js/-/core-js-3.40.0.tgz", + "integrity": "sha512-7vsMc/Lty6AGnn7uFpYT56QesI5D2Y/UkgKounk87OP9Z2H9Z8kj6jzcSGAxFmUtDOS0ntK6lbQz+Nsa0Jj6mQ==", + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/core-js" + } + }, "node_modules/core-util-is": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", @@ -4642,6 +4732,16 @@ "node": ">= 8" } }, + "node_modules/css-line-break": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/css-line-break/-/css-line-break-2.1.0.tgz", + "integrity": "sha512-FHcKFCZcAha3LwfVBhCQbW2nCNbkZXn7KVUJcsT5/P8YmfsVja0FMPJr0B903j/E69HUphKiV9iQArX8SDYA4w==", + "license": "MIT", + "optional": true, + "dependencies": { + "utrie": "^1.0.2" + } + }, "node_modules/css-select": { "version": "5.1.0", "resolved": "https://registry.npmjs.org/css-select/-/css-select-5.1.0.tgz", @@ -6105,6 +6205,12 @@ "pend": "~1.2.0" } }, + "node_modules/fflate": { + "version": "0.8.2", + "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz", + "integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==", + "license": "MIT" + }, "node_modules/figures": 
{ "version": "3.2.0", "resolved": "https://registry.npmjs.org/figures/-/figures-3.2.0.tgz", @@ -6700,6 +6806,20 @@ "resolved": "https://registry.npmjs.org/html-escaper/-/html-escaper-3.0.3.tgz", "integrity": "sha512-RuMffC89BOWQoY0WKGpIhn5gX3iI54O6nRA0yC124NYVtzjmFWBIiFd8M0x+ZdX0P9R4lADg1mgP8C7PxGOWuQ==" }, + "node_modules/html2canvas": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/html2canvas/-/html2canvas-1.4.1.tgz", + "integrity": "sha512-fPU6BHNpsyIhr8yyMpTLLxAbkaK8ArIBcmZIRiBLiDhjeqvXolaEmDGmELFuX9I4xDcaKKcJl+TKZLqruBbmWA==", + "license": "MIT", + "optional": true, + "dependencies": { + "css-line-break": "^2.1.0", + "text-segmentation": "^1.0.3" + }, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/htmlparser2": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/htmlparser2/-/htmlparser2-8.0.2.tgz", @@ -7212,6 +7332,34 @@ "graceful-fs": "^4.1.6" } }, + "node_modules/jspdf": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/jspdf/-/jspdf-3.0.0.tgz", + "integrity": "sha512-QvuQZvOI8CjfjVgtajdL0ihrDYif1cN5gXiF9lb9Pd9JOpmocvnNyFO9sdiJ/8RA5Bu8zyGOUjJLj5kiku16ug==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.26.0", + "atob": "^2.1.2", + "btoa": "^1.2.1", + "fflate": "^0.8.1" + }, + "optionalDependencies": { + "canvg": "^3.0.6", + "core-js": "^3.6.0", + "dompurify": "^3.2.4", + "html2canvas": "^1.0.0-rc.5" + } + }, + "node_modules/jspdf/node_modules/dompurify": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.2.4.tgz", + "integrity": "sha512-ysFSFEDVduQpyhzAob/kkuJjf5zWkZD8/A9ywSp1byueyuCfHamrCBa14/Oc2iiB0e51B+NpxSl5gmzn+Ms/mg==", + "license": "(MPL-2.0 OR Apache-2.0)", + "optional": true, + "optionalDependencies": { + "@types/trusted-types": "^2.0.7" + } + }, "node_modules/jsprim": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/jsprim/-/jsprim-2.0.2.tgz", @@ -9028,7 +9176,7 @@ "version": "2.1.0", "resolved": 
"https://registry.npmjs.org/performance-now/-/performance-now-2.1.0.tgz", "integrity": "sha512-7EAHlyLHI56VEIdK57uwHdHKIaAGbnXPiw0yWbarQZOKaKpvUIgW0jWRVLiatnM+XXlSwsanIBH/hzGMJulMow==", - "dev": true + "devOptional": true }, "node_modules/periscopic": { "version": "3.1.0", @@ -9754,6 +9902,16 @@ "rimraf": "bin.js" } }, + "node_modules/raf": { + "version": "3.4.1", + "resolved": "https://registry.npmjs.org/raf/-/raf-3.4.1.tgz", + "integrity": "sha512-Sq4CW4QhwOHE8ucn6J34MqtZCeWFP2aQSmrlroYgqAV1PjStIhJXxYuTgUIfkEk7zTLjmIjLmU5q+fbD1NnOJA==", + "license": "MIT", + "optional": true, + "dependencies": { + "performance-now": "^2.1.0" + } + }, "node_modules/react-is": { "version": "18.3.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.3.1.tgz", @@ -9893,6 +10051,16 @@ "integrity": "sha512-r5a3l5HzYlIC68TpmYKlxWjmOP6wiPJ1vWv2HeLhNsRZMrCkxeqxiHlQ21oXmQ4F3SiryXBHhAD7JZqvOJjFmg==", "dev": true }, + "node_modules/rgbcolor": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/rgbcolor/-/rgbcolor-1.0.1.tgz", + "integrity": "sha512-9aZLIrhRaD97sgVhtJOW6ckOEh6/GnvQtdVNfdZ6s67+3/XwLS9lBcQYzEEhYVeUowN7pRzMLsyGhK2i/xvWbw==", + "license": "MIT OR SEE LICENSE IN FEEL-FREE.md", + "optional": true, + "engines": { + "node": ">= 0.8.15" + } + }, "node_modules/rimraf": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", @@ -10784,6 +10952,16 @@ "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", "dev": true }, + "node_modules/stackblur-canvas": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/stackblur-canvas/-/stackblur-canvas-2.7.0.tgz", + "integrity": "sha512-yf7OENo23AGJhBriGx0QivY5JP6Y1HbrrDI6WLt6C5auYZXlQrheoY8hD4ibekFKz1HOfE48Ww8kMWMnJD/zcQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=0.1.14" + } + }, "node_modules/std-env": { "version": "3.7.0", "resolved": 
"https://registry.npmjs.org/std-env/-/std-env-3.7.0.tgz", @@ -11165,6 +11343,16 @@ "@types/estree": "*" } }, + "node_modules/svg-pathdata": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/svg-pathdata/-/svg-pathdata-6.0.3.tgz", + "integrity": "sha512-qsjeeq5YjBZ5eMdFuUa4ZosMLxgr5RZ+F+Y1OrDhuOCEInRMA3x74XdBtggJcj9kOeInz0WE+LgCPDkZFlBYJw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/symlink-or-copy": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/symlink-or-copy/-/symlink-or-copy-1.3.1.tgz", @@ -11257,6 +11445,16 @@ "streamx": "^2.12.5" } }, + "node_modules/text-segmentation": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/text-segmentation/-/text-segmentation-1.0.3.tgz", + "integrity": "sha512-iOiPUo/BGnZ6+54OsWxZidGCsdU8YbE4PSpdPinp7DeMtUJNJBoJ/ouUSTJjHkh1KntHaltHl/gDs2FC4i5+Nw==", + "license": "MIT", + "optional": true, + "dependencies": { + "utrie": "^1.0.2" + } + }, "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", @@ -11528,6 +11726,15 @@ "node": "*" } }, + "node_modules/undici": { + "version": "7.3.0", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.3.0.tgz", + "integrity": "sha512-Qy96NND4Dou5jKoSJ2gm8ax8AJM/Ey9o9mz7KN1bb9GP+G0l20Zw8afxTnY2f4b7hmhn/z8aC2kfArVQlAhFBw==", + "license": "MIT", + "engines": { + "node": ">=20.18.1" + } + }, "node_modules/undici-types": { "version": "5.26.5", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", @@ -11587,6 +11794,16 @@ "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", "dev": true }, + "node_modules/utrie": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/utrie/-/utrie-1.0.2.tgz", + "integrity": "sha512-1MLa5ouZiOmQzUbjbu9VmjLzn1QLXBhwpUa7kdLUQK+KQ5KA9I1vk5U4YHe/X2Ch7PYnJfWuWT+VbuxbGwljhw==", + "license": "MIT", + 
"optional": true, + "dependencies": { + "base64-arraybuffer": "^1.0.2" + } + }, "node_modules/uuid": { "version": "9.0.1", "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.1.tgz", diff --git a/package.json b/package.json index 6460a0297e..1d2e86741a 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.5.12", + "version": "0.5.18", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", @@ -84,6 +84,7 @@ "i18next-resources-to-backend": "^1.2.0", "idb": "^7.1.1", "js-sha256": "^0.10.1", + "jspdf": "^3.0.0", "katex": "^0.16.21", "kokoro-js": "^1.1.1", "marked": "^9.1.0", @@ -106,6 +107,7 @@ "svelte-sonner": "^0.3.19", "tippy.js": "^6.3.7", "turndown": "^7.2.0", + "undici": "^7.3.0", "uuid": "^9.0.1", "vite-plugin-static-copy": "^2.2.0" }, diff --git a/postcss.config.js b/postcss.config.js index c5c819c31d..85b958cb59 100644 --- a/postcss.config.js +++ b/postcss.config.js @@ -1,5 +1,5 @@ export default { plugins: { - '@tailwindcss/postcss': {}, + '@tailwindcss/postcss': {} } }; diff --git a/pyproject.toml b/pyproject.toml index dac8bbf788..ccf4863462 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,11 +8,11 @@ license = { file = "LICENSE" } dependencies = [ "fastapi==0.115.7", "uvicorn[standard]==0.30.6", - "pydantic==2.9.2", + "pydantic==2.10.6", "python-multipart==0.0.18", "python-socketio==5.11.3", - "python-jose==3.3.0", + "python-jose==3.4.0", "passlib[bcrypt]==1.7.4", "requests==2.32.3", @@ -40,6 +40,9 @@ dependencies = [ "RestrictedPython==8.0", + "loguru==0.7.2", + "asgiref==3.8.1", + "openai", "anthropic", "google-generativeai==0.7.2", @@ -53,6 +56,7 @@ dependencies = [ "pymilvus==2.5.0", "qdrant-client~=1.12.0", "opensearch-py==2.8.0", + "playwright==1.49.1", "transformers", "sentence-transformers==3.3.1", @@ -65,7 +69,7 @@ dependencies = [ "pymdown-extensions==10.14.2", "docx2txt==0.8", "python-pptx==1.0.0", - "unstructured==0.16.11", + "unstructured==0.16.17", 
"nltk==3.9.1", "Markdown==3.7", "pypandoc==1.13", @@ -77,6 +81,7 @@ dependencies = [ "psutil", "sentencepiece", "soundfile==0.13.1", + "azure-ai-documentintelligence==1.0.0", "opencv-python-headless==4.11.0.86", "rapidocr-onnxruntime==1.3.24", @@ -108,7 +113,13 @@ dependencies = [ "googleapis-common-protos==1.63.2", "google-cloud-storage==2.19.0", + "azure-identity==1.20.0", + "azure-storage-blob==12.24.1", + "ldap3==2.9.1", + + "firecrawl-py==1.12.0", + "gcp-storage-emulator>=2024.8.3", ] readme = "README.md" diff --git a/run-compose.sh b/run-compose.sh index 21574e9599..4fafedc6f7 100755 --- a/run-compose.sh +++ b/run-compose.sh @@ -74,6 +74,7 @@ usage() { echo " --enable-api[port=PORT] Enable API and expose it on the specified port." echo " --webui[port=PORT] Set the port for the web user interface." echo " --data[folder=PATH] Bind mount for ollama data folder (by default will create the 'ollama' volume)." + echo " --playwright Enable Playwright support for web scraping." echo " --build Build the docker image before running the compose project." echo " --drop Drop the compose project." echo " -q, --quiet Run script in headless mode." 
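Both `requirements.txt` and `pyproject.toml` pin `playwright==1.49.1`, with a comment warning that the version must match the image tag in `docker-compose.playwright.yaml` (`mcr.microsoft.com/playwright:v1.49.1-noble`). A small illustrative consistency check (the helper name and sample strings are mine, not part of the patch):

```python
import re

def playwright_versions_match(requirements_line: str, compose_image: str) -> bool:
    # Extract the pinned version from a requirements.txt line and from
    # the docker-compose image tag, then confirm they agree.
    req = re.search(r"playwright==(\d+\.\d+\.\d+)", requirements_line)
    img = re.search(r"playwright:v(\d+\.\d+\.\d+)", compose_image)
    return bool(req and img and req.group(1) == img.group(1))

# Sample inputs taken from the diff above
print(playwright_versions_match(
    "playwright==1.49.1  # Caution: version must match docker-compose.playwright.yaml",
    "mcr.microsoft.com/playwright:v1.49.1-noble",
))  # → True
```

A check like this could run in CI so that bumping one pin without the other fails fast instead of producing a client/server protocol mismatch at runtime.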
@@ -100,6 +101,7 @@ webui_port=3000 headless=false build_image=false kill_compose=false +enable_playwright=false # Function to extract value from the parameter extract_value() { @@ -129,6 +131,9 @@ while [[ $# -gt 0 ]]; do value=$(extract_value "$key") data_dir=${value:-"./ollama-data"} ;; + --playwright) + enable_playwright=true + ;; --drop) kill_compose=true ;; @@ -182,6 +187,9 @@ else DEFAULT_COMPOSE_COMMAND+=" -f docker-compose.data.yaml" export OLLAMA_DATA_DIR=$data_dir # Set OLLAMA_DATA_DIR environment variable fi + if [[ $enable_playwright == true ]]; then + DEFAULT_COMPOSE_COMMAND+=" -f docker-compose.playwright.yaml" + fi if [[ -n $webui_port ]]; then export OPEN_WEBUI_PORT=$webui_port # Set OPEN_WEBUI_PORT environment variable fi @@ -201,6 +209,7 @@ echo -e " ${GREEN}${BOLD}GPU Count:${NC} ${OLLAMA_GPU_COUNT:-Not Enabled}" echo -e " ${GREEN}${BOLD}WebAPI Port:${NC} ${OLLAMA_WEBAPI_PORT:-Not Enabled}" echo -e " ${GREEN}${BOLD}Data Folder:${NC} ${data_dir:-Using ollama volume}" echo -e " ${GREEN}${BOLD}WebUI Port:${NC} $webui_port" +echo -e " ${GREEN}${BOLD}Playwright:${NC} ${enable_playwright:-false}" echo if [[ $headless == true ]]; then diff --git a/scripts/prepare-pyodide.js b/scripts/prepare-pyodide.js index 71f2a2cb29..70f3cf5c6c 100644 --- a/scripts/prepare-pyodide.js +++ b/scripts/prepare-pyodide.js @@ -16,8 +16,39 @@ const packages = [ ]; import { loadPyodide } from 'pyodide'; +import { setGlobalDispatcher, ProxyAgent } from 'undici'; import { writeFile, readFile, copyFile, readdir, rmdir } from 'fs/promises'; +/** + * Loading network proxy configurations from the environment variables. + * And the proxy config with lowercase name has the highest priority to use. 
+ */ +function initNetworkProxyFromEnv() { + // we assume all subsequent requests in this script are HTTPS: + // https://cdn.jsdelivr.net + // https://pypi.org + // https://files.pythonhosted.org + const allProxy = process.env.all_proxy || process.env.ALL_PROXY; + const httpsProxy = process.env.https_proxy || process.env.HTTPS_PROXY; + const httpProxy = process.env.http_proxy || process.env.HTTP_PROXY; + const preferedProxy = httpsProxy || allProxy || httpProxy; + /** + * use only http(s) proxy because socks5 proxy is not supported currently: + * @see https://github.com/nodejs/undici/issues/2224 + */ + if (!preferedProxy || !preferedProxy.startsWith('http')) return; + let preferedProxyURL; + try { + preferedProxyURL = new URL(preferedProxy).toString(); + } catch { + console.warn(`Invalid network proxy URL: "${preferedProxy}"`); + return; + } + const dispatcher = new ProxyAgent({ uri: preferedProxyURL }); + setGlobalDispatcher(dispatcher); + console.log(`Initialized network proxy "${preferedProxy}" from env`); +} + async function downloadPackages() { console.log('Setting up pyodide + micropip'); @@ -84,5 +115,6 @@ async function copyPyodide() { } } +initNetworkProxyFromEnv(); await downloadPackages(); await copyPyodide(); diff --git a/src/app.css b/src/app.css index d547b91dac..8bdc6f1ade 100644 --- a/src/app.css +++ b/src/app.css @@ -55,11 +55,11 @@ math { } .markdown-prose { - @apply prose dark:prose-invert prose-blockquote:border-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-l-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 
prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .markdown-prose-xs { - @apply text-xs prose dark:prose-invert prose-blockquote:border-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-l-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-0 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; + @apply text-xs prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-0 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } .markdown a { @@ -101,7 +101,7 @@ li p { /* Dark theme scrollbar styles */ .dark ::-webkit-scrollbar-thumb { - background-color: rgba(33, 33, 33, 0.8); /* Darker color for dark theme */ + background-color: rgba(42, 42, 42, 0.8); /* Darker color for dark theme */ border-color: rgba(0, 0, 0, var(--tw-border-opacity)); } diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index d7f02564ce..f7f02c7405 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -115,10 +115,10 @@ export const setDirectConnectionsConfig = async (token: string, config: object) return res; }; -export const getCodeInterpreterConfig = async (token: string) => { +export const 
getCodeExecutionConfig = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_interpreter`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_execution`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -142,10 +142,10 @@ export const getCodeInterpreterConfig = async (token: string) => { return res; }; -export const setCodeInterpreterConfig = async (token: string, config: object) => { +export const setCodeExecutionConfig = async (token: string, config: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_interpreter`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/code_execution`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index c35c37847b..31317fe0b9 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -32,9 +32,15 @@ type ChunkConfigForm = { chunk_overlap: number; }; +type DocumentIntelligenceConfigForm = { + key: string; + endpoint: string; +}; + type ContentExtractConfigForm = { engine: string; tika_server_url: string | null; + document_intelligence_config: DocumentIntelligenceConfigForm | null; }; type YoutubeConfigForm = { @@ -46,6 +52,7 @@ type YoutubeConfigForm = { type RAGConfigForm = { pdf_extract_images?: boolean; enable_google_drive_integration?: boolean; + enable_onedrive_integration?: boolean; chunk?: ChunkConfigForm; content_extraction?: ContentExtractConfigForm; web_loader_ssl_verification?: boolean; diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index b0efe39d27..f479b130c8 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -284,14 +284,16 @@ export const updateUserInfo = async (token: string, info: object) => { export const getAndUpdateUserLocation = async (token: string) => { const location = await getUserPosition().catch((err) 
=> { - throw err; + console.log(err); + return null; }); if (location) { await updateUserInfo(token, { location: location }); return location; } else { - throw new Error('Failed to get user location'); + console.log('Failed to get user location'); + return null; } }; diff --git a/src/lib/apis/utils/index.ts b/src/lib/apis/utils/index.ts index 40fdbfcfa2..64db561249 100644 --- a/src/lib/apis/utils/index.ts +++ b/src/lib/apis/utils/index.ts @@ -1,12 +1,13 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const getGravatarUrl = async (email: string) => { +export const getGravatarUrl = async (token: string, email: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/utils/gravatar?email=${email}`, { method: 'GET', headers: { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` } }) .then(async (res) => { @@ -22,13 +23,14 @@ export const getGravatarUrl = async (email: string) => { return res; }; -export const formatPythonCode = async (code: string) => { +export const executeCode = async (token: string, code: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/utils/code/format`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/utils/code/execute`, { method: 'POST', headers: { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` }, body: JSON.stringify({ code: code @@ -55,13 +57,48 @@ export const formatPythonCode = async (code: string) => { return res; }; -export const downloadChatAsPDF = async (title: string, messages: object[]) => { +export const formatPythonCode = async (token: string, code: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/utils/code/format`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + code: code + }) + }) + .then(async (res) => { + if 
(!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + + error = err; + if (err.detail) { + error = err.detail; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const downloadChatAsPDF = async (token: string, title: string, messages: object[]) => { let error = null; const blob = await fetch(`${WEBUI_API_BASE_URL}/utils/pdf`, { method: 'POST', headers: { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` }, body: JSON.stringify({ title: title, @@ -81,13 +118,14 @@ export const downloadChatAsPDF = async (title: string, messages: object[]) => { return blob; }; -export const getHTMLFromMarkdown = async (md: string) => { +export const getHTMLFromMarkdown = async (token: string, md: string) => { let error = null; const res = await fetch(`${WEBUI_API_BASE_URL}/utils/markdown`, { method: 'POST', headers: { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` }, body: JSON.stringify({ md: md diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte index 07500c01c8..026755b8a3 100644 --- a/src/lib/components/admin/Evaluations/Feedbacks.svelte +++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte @@ -131,7 +131,9 @@ -
+
{#if (feedbacks ?? []).length === 0}
{$i18n.t('No feedbacks found')} diff --git a/src/lib/components/admin/Evaluations/Leaderboard.svelte b/src/lib/components/admin/Evaluations/Leaderboard.svelte index 7412e799fb..e5d8a21662 100644 --- a/src/lib/components/admin/Evaluations/Leaderboard.svelte +++ b/src/lib/components/admin/Evaluations/Leaderboard.svelte @@ -300,7 +300,9 @@
-
+
{#if loadingLeaderboard}
diff --git a/src/lib/components/admin/Functions/FunctionEditor.svelte b/src/lib/components/admin/Functions/FunctionEditor.svelte index cbdec24257..6da2a83f45 100644 --- a/src/lib/components/admin/Functions/FunctionEditor.svelte +++ b/src/lib/components/admin/Functions/FunctionEditor.svelte @@ -1,8 +1,7 @@ + +
{ + await submitHandler(); + saveHandler(); + }} +> +
+ {#if config} +
+
+
{$i18n.t('General')}
+ +
+ +
+
+
{$i18n.t('Code Execution Engine')}
+
+ +
+
+ + {#if config.CODE_EXECUTION_ENGINE === 'jupyter'} +
+ {$i18n.t( + 'Warning: Jupyter execution enables arbitrary code execution, posing severe security risks—proceed with extreme caution.' + )} +
+ {/if} +
+ + {#if config.CODE_EXECUTION_ENGINE === 'jupyter'} +
+
+ {$i18n.t('Jupyter URL')} +
+ +
+
+ +
+
+
+ +
+
+
+ {$i18n.t('Jupyter Auth')} +
+ +
+ +
+
+ + {#if config.CODE_EXECUTION_JUPYTER_AUTH} +
+
+ {#if config.CODE_EXECUTION_JUPYTER_AUTH === 'password'} + + {:else} + + {/if} +
+
+ {/if} +
+ +
+
+ {$i18n.t('Code Execution Timeout')} +
+ +
+ + + +
+
+ {/if} +
+ +
+
{$i18n.t('Code Interpreter')}
+ +
+ +
+
+
+ {$i18n.t('Enable Code Interpreter')} +
+ + +
+
+ + {#if config.ENABLE_CODE_INTERPRETER} +
+
+
+ {$i18n.t('Code Interpreter Engine')} +
+
+ +
+
+ + {#if config.CODE_INTERPRETER_ENGINE === 'jupyter'} +
+ {$i18n.t( + 'Warning: Jupyter execution enables arbitrary code execution, posing severe security risks—proceed with extreme caution.' + )} +
+ {/if} +
+ + {#if config.CODE_INTERPRETER_ENGINE === 'jupyter'} +
+
+ {$i18n.t('Jupyter URL')} +
+ +
+
+ +
+
+
+ +
+
+
+ {$i18n.t('Jupyter Auth')} +
+ +
+ +
+
+ + {#if config.CODE_INTERPRETER_JUPYTER_AUTH} +
+
+ {#if config.CODE_INTERPRETER_JUPYTER_AUTH === 'password'} + + {:else} + + {/if} +
+
+ {/if} +
+ +
+
+ {$i18n.t('Code Execution Timeout')} +
+ +
+ + + +
+
+ {/if} + +
+ +
+
+
+ {$i18n.t('Code Interpreter Prompt Template')} +
+ + +
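The frontend hunks above rename the code-interpreter config endpoints to `/configs/code_execution` and attach an `Authorization: Bearer` header to the utils calls (`/utils/code/format`, `/utils/code/execute`, `/utils/pdf`, `/utils/markdown`, `/utils/gravatar`). A minimal Python sketch of an equivalent authenticated request against the renamed endpoint; the base URL and token are placeholders, not values from the diff:

```python
from urllib import request

def build_code_execution_request(base_url: str, token: str) -> request.Request:
    # Mirrors the renamed frontend call: GET /configs/code_execution
    # (formerly /configs/code_interpreter) with a Bearer token.
    return request.Request(
        f"{base_url}/configs/code_execution",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
    )

# Placeholder base URL and token for illustration only
req = build_code_execution_request("http://localhost:8080/api/v1", "TOKEN")
print(req.full_url)  # → http://localhost:8080/api/v1/configs/code_execution
```

Sending it is then a matter of `request.urlopen(req)` against a running Open WebUI instance; requests without the Bearer header will be rejected now that these routes require authentication.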