diff --git a/.env.example b/.env.example
index c38bf88bfb..35ea12a885 100644
--- a/.env.example
+++ b/.env.example
@@ -7,6 +7,15 @@ OPENAI_API_KEY=''
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
+# For production, you should only need one host, as
+# FastAPI serves the SvelteKit-built frontend and the backend from the same host and port.
+# To test with CORS locally, you can set something like
+# CORS_ALLOW_ORIGIN='http://localhost:5173;http://localhost:8080'
+CORS_ALLOW_ORIGIN='*'
+
+# For production, you should set this to match your proxy configuration (e.g. 127.0.0.1)
+FORWARDED_ALLOW_IPS='*'
+
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true
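
The backend splits this variable on semicolons before handing it to the CORS middleware. A minimal sketch of how a multi-origin value is consumed (the middleware options shown are illustrative):

```python
import os

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# Semicolon-separated, mirroring the .env example above; defaults to the wildcard.
origins = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # either ["*"] or explicit origins, never a mix
    allow_methods=["*"],
    allow_headers=["*"],
)
```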
diff --git a/.gitattributes b/.gitattributes
index 526c8a38d4..bf368a4c6c 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1 +1,49 @@
-*.sh text eol=lf
\ No newline at end of file
+# TypeScript
+*.ts text eol=lf
+*.tsx text eol=lf
+
+# JavaScript
+*.js text eol=lf
+*.jsx text eol=lf
+*.mjs text eol=lf
+*.cjs text eol=lf
+
+# Svelte
+*.svelte text eol=lf
+
+# HTML/CSS
+*.html text eol=lf
+*.css text eol=lf
+*.scss text eol=lf
+*.less text eol=lf
+
+# Config files and JSON
+*.json text eol=lf
+*.jsonc text eol=lf
+*.yml text eol=lf
+*.yaml text eol=lf
+*.toml text eol=lf
+
+# Shell scripts
+*.sh text eol=lf
+
+# Markdown & docs
+*.md text eol=lf
+*.mdx text eol=lf
+*.txt text eol=lf
+
+# Git-related
+.gitattributes text eol=lf
+.gitignore text eol=lf
+
+# Prettier and other dotfiles
+.prettierrc text eol=lf
+.prettierignore text eol=lf
+.eslintrc text eol=lf
+.eslintignore text eol=lf
+.stylelintrc text eol=lf
+.editorconfig text eol=lf
+
+# Misc
+*.env text eol=lf
+*.lock text eol=lf
\ No newline at end of file
diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml
index e61a69f33a..821ffb7206 100644
--- a/.github/workflows/docker-build.yaml
+++ b/.github/workflows/docker-build.yaml
@@ -14,16 +14,18 @@ env:
jobs:
build-main-image:
- runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
+ runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
- platform:
- - linux/amd64
- - linux/arm64
+ include:
+ - platform: linux/amd64
+ runner: ubuntu-latest
+ - platform: linux/arm64
+ runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@@ -111,16 +113,18 @@ jobs:
retention-days: 1
build-cuda-image:
- runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
+ runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
- platform:
- - linux/amd64
- - linux/arm64
+ include:
+ - platform: linux/amd64
+ runner: ubuntu-latest
+ - platform: linux/arm64
+ runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@@ -210,17 +214,122 @@ jobs:
if-no-files-found: error
retention-days: 1
- build-ollama-image:
- runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
+ build-cuda126-image:
+ runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
- platform:
- - linux/amd64
- - linux/arm64
+ include:
+ - platform: linux/amd64
+ runner: ubuntu-latest
+ - platform: linux/arm64
+ runner: ubuntu-24.04-arm
+
+ steps:
+    # GitHub Packages requires the entire repository name to be lowercase.
+    # Although this repository's owner is already lowercase, forks may not be, so the name is lowercased explicitly to keep actions working after forking.
+ - name: Set repository and image name to lowercase
+ run: |
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
+ env:
+ IMAGE_NAME: '${{ github.repository }}'
+
+ - name: Prepare
+ run: |
+ platform=${{ matrix.platform }}
+ echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
+
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v3
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v3
+ with:
+ registry: ${{ env.REGISTRY }}
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Extract metadata for Docker images (cuda126 tag)
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: ${{ env.FULL_IMAGE_NAME }}
+ tags: |
+ type=ref,event=branch
+ type=ref,event=tag
+ type=sha,prefix=git-
+ type=semver,pattern={{version}}
+ type=semver,pattern={{major}}.{{minor}}
+ type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126
+ flavor: |
+ latest=${{ github.ref == 'refs/heads/main' }}
+ suffix=-cuda126,onlatest=true
+
+ - name: Extract metadata for Docker cache
+ id: cache-meta
+ uses: docker/metadata-action@v5
+ with:
+ images: ${{ env.FULL_IMAGE_NAME }}
+ tags: |
+ type=ref,event=branch
+ ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
+ flavor: |
+ prefix=cache-cuda126-${{ matrix.platform }}-
+ latest=false
+
+ - name: Build Docker image (cuda126)
+ uses: docker/build-push-action@v5
+ id: build
+ with:
+ context: .
+ push: true
+ platforms: ${{ matrix.platform }}
+ labels: ${{ steps.meta.outputs.labels }}
+ outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
+ cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
+ cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
+ build-args: |
+ BUILD_HASH=${{ github.sha }}
+ USE_CUDA=true
+ USE_CUDA_VER=cu126
+
+ - name: Export digest
+ run: |
+ mkdir -p /tmp/digests
+ digest="${{ steps.build.outputs.digest }}"
+ touch "/tmp/digests/${digest#sha256:}"
+
+ - name: Upload digest
+ uses: actions/upload-artifact@v4
+ with:
+ name: digests-cuda126-${{ env.PLATFORM_PAIR }}
+ path: /tmp/digests/*
+ if-no-files-found: error
+ retention-days: 1
+
+ build-ollama-image:
+ runs-on: ${{ matrix.runner }}
+ permissions:
+ contents: read
+ packages: write
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - platform: linux/amd64
+ runner: ubuntu-latest
+ - platform: linux/arm64
+ runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@@ -420,6 +529,62 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
+ merge-cuda126-images:
+ runs-on: ubuntu-latest
+ needs: [build-cuda126-image]
+ steps:
+      # GitHub Packages requires the entire repository name to be lowercase.
+      # Although this repository's owner is already lowercase, forks may not be, so the name is lowercased explicitly to keep actions working after forking.
+ - name: Set repository and image name to lowercase
+ run: |
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
+ env:
+ IMAGE_NAME: '${{ github.repository }}'
+
+ - name: Download digests
+ uses: actions/download-artifact@v4
+ with:
+ pattern: digests-cuda126-*
+ path: /tmp/digests
+ merge-multiple: true
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v3
+ with:
+ registry: ${{ env.REGISTRY }}
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract metadata for Docker images (cuda126 tag)
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: ${{ env.FULL_IMAGE_NAME }}
+ tags: |
+ type=ref,event=branch
+ type=ref,event=tag
+ type=sha,prefix=git-
+ type=semver,pattern={{version}}
+ type=semver,pattern={{major}}.{{minor}}
+ type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126
+ flavor: |
+ latest=${{ github.ref == 'refs/heads/main' }}
+ suffix=-cuda126,onlatest=true
+
+ - name: Create manifest list and push
+ working-directory: /tmp/digests
+ run: |
+ docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
+ $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
+
+ - name: Inspect image
+ run: |
+ docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
+
merge-ollama-images:
runs-on: ubuntu-latest
needs: [build-ollama-image]
diff --git a/.prettierrc b/.prettierrc
index a77fddea90..22558729f4 100644
--- a/.prettierrc
+++ b/.prettierrc
@@ -5,5 +5,6 @@
"printWidth": 100,
"plugins": ["prettier-plugin-svelte"],
"pluginSearchDirs": ["."],
- "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
+ "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }],
+ "endOfLine": "lf"
}
diff --git a/README.md b/README.md
index 2ad208c3f7..dce1022fc6 100644
--- a/README.md
+++ b/README.md
@@ -79,7 +79,7 @@ Want to learn more about Open WebUI's features? Check out our [Open WebUI docume
-
+
|
@@ -181,6 +181,8 @@ After installation, you can access Open WebUI at [http://localhost:3000](http://
We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
+See the [Local Development Guide](https://docs.openwebui.com/getting-started/advanced-topics/development) for instructions on setting up a local development environment.
+
### Troubleshooting
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
diff --git a/backend/dev.sh b/backend/dev.sh
index 5449ab7777..22d9527656 100755
--- a/backend/dev.sh
+++ b/backend/dev.sh
@@ -1,2 +1,2 @@
PORT="${PORT:-8080}"
-uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
\ No newline at end of file
+uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --reload
\ No newline at end of file
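
The `--forwarded-allow-ips '*'` flag can be dropped here because uvicorn falls back to the `FORWARDED_ALLOW_IPS` environment variable (now provided via `.env.example` above) when the flag is omitted. A sketch of the programmatic equivalent:

```python
import os

import uvicorn

# When the CLI flag is omitted, uvicorn reads FORWARDED_ALLOW_IPS from the
# environment, defaulting to 127.0.0.1; dev.sh no longer hard-codes '*'.
uvicorn.run(
    "open_webui.main:app",
    host="0.0.0.0",
    port=int(os.environ.get("PORT", "8080")),
    forwarded_allow_ips=os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1"),
    reload=True,
)
```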
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 0f49483610..79d2db84bd 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -901,9 +901,7 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig(
####################################
-WEBUI_URL = PersistentConfig(
- "WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
-)
+WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", ""))
ENABLE_SIGNUP = PersistentConfig(
@@ -1247,12 +1245,6 @@ if THREAD_POOL_SIZE is not None and isinstance(THREAD_POOL_SIZE, str):
THREAD_POOL_SIZE = None
-def validate_cors_origins(origins):
- for origin in origins:
- if origin != "*":
- validate_cors_origin(origin)
-
-
def validate_cors_origin(origin):
parsed_url = urlparse(origin)
@@ -1272,16 +1264,17 @@ def validate_cors_origin(origin):
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
-CORS_ALLOW_ORIGIN = os.environ.get(
- "CORS_ALLOW_ORIGIN", "*;http://localhost:5173;http://localhost:8080"
-).split(";")
+CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
-if "*" in CORS_ALLOW_ORIGIN:
+if CORS_ALLOW_ORIGIN == ["*"]:
log.warning(
"\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
)
-
-validate_cors_origins(CORS_ALLOW_ORIGIN)
+else:
+    # Use either the single wildcard or an explicit list of origins, not both;
+    # mixing them results in CORS errors in the browser.
+ for origin in CORS_ALLOW_ORIGIN:
+ validate_cors_origin(origin)
class BannerModel(BaseModel):
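
The rewritten check treats the wildcard as mutually exclusive with an explicit origin list. A self-contained sketch of the same logic, with a simplified validator standing in for the module's `validate_cors_origin`:

```python
from urllib.parse import urlparse


def validate_cors_origin(origin: str) -> None:
    # Simplified stand-in: the real validator lives in config.py.
    parsed = urlparse(origin)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(f"Invalid CORS origin: {origin}")


def parse_cors_origins(raw: str) -> list[str]:
    origins = raw.split(";")
    if origins == ["*"]:
        return origins  # the wildcard must stand alone
    for origin in origins:
        validate_cors_origin(origin)
    return origins


assert parse_cors_origins("*") == ["*"]
assert parse_cors_origins("http://localhost:5173;http://localhost:8080") == [
    "http://localhost:5173",
    "http://localhost:8080",
]
```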
@@ -1413,6 +1406,35 @@ Strictly return in JSON format:
{{MESSAGES:END:6}}
"""
+
+FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+ "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
+ "task.follow_up.prompt_template",
+ os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
+)
+
+DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
+Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
+### Guidelines:
+- Write all follow-up questions from the user’s point of view, directed to the assistant.
+- Make questions concise, clear, and directly related to the discussed topic(s).
+- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
+- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
+- Use the conversation's primary language; default to English if multilingual.
+- Response must be a JSON object with a "follow_ups" array of strings, no extra text or formatting.
+### Output:
+JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
+### Chat History:
+
+{{MESSAGES:END:6}}
+"""
+
+ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
+ "ENABLE_FOLLOW_UP_GENERATION",
+ "task.follow_up.enable",
+ os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
+)
+
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",
@@ -1945,6 +1967,40 @@ DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig(
os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true",
)
+DOCLING_PICTURE_DESCRIPTION_MODE = PersistentConfig(
+ "DOCLING_PICTURE_DESCRIPTION_MODE",
+ "rag.docling_picture_description_mode",
+ os.getenv("DOCLING_PICTURE_DESCRIPTION_MODE", ""),
+)
+
+
+docling_picture_description_local = os.getenv("DOCLING_PICTURE_DESCRIPTION_LOCAL", "")
+try:
+ docling_picture_description_local = json.loads(docling_picture_description_local)
+except json.JSONDecodeError:
+ docling_picture_description_local = {}
+
+
+DOCLING_PICTURE_DESCRIPTION_LOCAL = PersistentConfig(
+ "DOCLING_PICTURE_DESCRIPTION_LOCAL",
+ "rag.docling_picture_description_local",
+ docling_picture_description_local,
+)
+
+docling_picture_description_api = os.getenv("DOCLING_PICTURE_DESCRIPTION_API", "")
+try:
+    docling_picture_description_api = json.loads(docling_picture_description_api)
+except json.JSONDecodeError:
+    docling_picture_description_api = {}
+
+
+DOCLING_PICTURE_DESCRIPTION_API = PersistentConfig(
+    "DOCLING_PICTURE_DESCRIPTION_API",
+    "rag.docling_picture_description_api",
+    docling_picture_description_api,
+)
+
+
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
"DOCUMENT_INTELLIGENCE_ENDPOINT",
"rag.document_intelligence_endpoint",
@@ -2444,6 +2500,18 @@ PERPLEXITY_API_KEY = PersistentConfig(
os.getenv("PERPLEXITY_API_KEY", ""),
)
+PERPLEXITY_MODEL = PersistentConfig(
+ "PERPLEXITY_MODEL",
+ "rag.web.search.perplexity_model",
+ os.getenv("PERPLEXITY_MODEL", "sonar"),
+)
+
+PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
+ "PERPLEXITY_SEARCH_CONTEXT_USAGE",
+ "rag.web.search.perplexity_search_context_usage",
+ os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
+)
+
SOUGOU_API_SID = PersistentConfig(
"SOUGOU_API_SID",
"rag.web.search.sougou_api_sid",
diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py
index 95c54a0d27..59ee6aaacb 100644
--- a/backend/open_webui/constants.py
+++ b/backend/open_webui/constants.py
@@ -111,6 +111,7 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
+ FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index fcfccaedf5..7601748376 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -5,6 +5,7 @@ import os
import pkgutil
import sys
import shutil
+from uuid import uuid4
from pathlib import Path
import markdown
@@ -130,6 +131,7 @@ else:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
+INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
# Function to parse each section
diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py
index 20fabb2dc7..6eb5c1bbdb 100644
--- a/backend/open_webui/functions.py
+++ b/backend/open_webui/functions.py
@@ -25,6 +25,7 @@ from open_webui.socket.main import (
)
+from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
@@ -227,12 +228,7 @@ async def generate_function_chat_completion(
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
- "__user__": {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- },
+ "__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
}
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 6bdcf4957a..1577b01707 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -8,6 +8,8 @@ import shutil
import sys
import time
import random
+from uuid import uuid4
+
from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse
@@ -19,6 +21,7 @@ from aiocache import cached
import aiohttp
import anyio.to_thread
import requests
+from redis import Redis
from fastapi import (
@@ -37,7 +40,7 @@ from fastapi import (
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, RedirectResponse
+from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from starlette_compress import CompressMiddleware
@@ -231,6 +234,9 @@ from open_webui.config import (
DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION,
+ DOCLING_PICTURE_DESCRIPTION_MODE,
+ DOCLING_PICTURE_DESCRIPTION_LOCAL,
+ DOCLING_PICTURE_DESCRIPTION_API,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY,
@@ -268,6 +274,8 @@ from open_webui.config import (
BRAVE_SEARCH_API_KEY,
EXA_API_KEY,
PERPLEXITY_API_KEY,
+ PERPLEXITY_MODEL,
+ PERPLEXITY_SEARCH_CONTEXT_USAGE,
SOUGOU_API_SID,
SOUGOU_API_SK,
KAGI_SEARCH_API_KEY,
@@ -359,10 +367,12 @@ from open_webui.config import (
TASK_MODEL_EXTERNAL,
ENABLE_TAGS_GENERATION,
ENABLE_TITLE_GENERATION,
+ ENABLE_FOLLOW_UP_GENERATION,
ENABLE_SEARCH_QUERY_GENERATION,
ENABLE_RETRIEVAL_QUERY_GENERATION,
ENABLE_AUTOCOMPLETE_GENERATION,
TITLE_GENERATION_PROMPT_TEMPLATE,
+ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -384,6 +394,7 @@ from open_webui.env import (
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
+ INSTANCE_ID,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
@@ -411,6 +422,7 @@ from open_webui.utils.chat import (
chat_completed as chat_completed_handler,
chat_action as chat_action_handler,
)
+from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access
@@ -424,8 +436,10 @@ from open_webui.utils.auth import (
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
+from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import (
+ redis_task_command_listener,
list_task_ids_by_chat_id,
stop_task,
list_tasks,
@@ -477,7 +491,9 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
+ app.state.instance_id = INSTANCE_ID
start_logger()
+
if RESET_CONFIG_ON_START:
reset_config()
@@ -489,6 +505,18 @@ async def lifespan(app: FastAPI):
log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies()
+ app.state.redis = get_redis_connection(
+ redis_url=REDIS_URL,
+ redis_sentinels=get_sentinels_from_env(
+ REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
+ ),
+ )
+
+ if isinstance(app.state.redis, Redis):
+ app.state.redis_task_command_listener = asyncio.create_task(
+ redis_task_command_listener(app)
+ )
+
if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = THREAD_POOL_SIZE
@@ -497,6 +525,9 @@ async def lifespan(app: FastAPI):
yield
+ if hasattr(app.state, "redis_task_command_listener"):
+ app.state.redis_task_command_listener.cancel()
+
app = FastAPI(
title="Open WebUI",
@@ -508,10 +539,12 @@ app = FastAPI(
oauth_manager = OAuthManager(app)
+app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
)
+app.state.redis = None
app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
@@ -696,6 +729,9 @@ app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
+app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
+app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
+app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
@@ -771,6 +807,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
+app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
+app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
@@ -959,6 +997,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
+app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -966,6 +1005,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
+app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
+ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -1197,6 +1239,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
return {"data": models}
+##################################
+# Embeddings
+##################################
+
+
+@app.post("/api/embeddings")
+async def embeddings(
+ request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+ """
+ OpenAI-compatible embeddings endpoint.
+
+ This handler:
+ - Performs user/model checks and dispatches to the correct backend.
+ - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
+
+ Args:
+ request (Request): Request context.
+ form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+ user (UserModel): Authenticated user.
+
+ Returns:
+ dict: OpenAI-compatible embeddings response.
+ """
+ # Make sure models are loaded in app state
+ if not request.app.state.MODELS:
+ await get_all_models(request, user=user)
+ # Use generic dispatcher in utils.embeddings
+ return await generate_embeddings(request, form_data, user)
+
+
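
Since the endpoint is OpenAI-compatible, any OpenAI-style client pointed at Open WebUI should work against it. A usage sketch; the base URL, API key, and model name are placeholders:

```python
import requests

resp = requests.post(
    "http://localhost:8080/api/embeddings",
    headers={"Authorization": "Bearer <your-api-key>"},
    json={"model": "text-embedding-3-small", "input": ["hello", "world"]},
    timeout=30,
)
resp.raise_for_status()
print(len(resp.json()["data"]))  # one embedding object per input string
```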
@app.post("/api/chat/completions")
async def chat_completion(
request: Request,
@@ -1338,26 +1411,30 @@ async def chat_action(
@app.post("/api/tasks/stop/{task_id}")
-async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
+async def stop_task_endpoint(
+ request: Request, task_id: str, user=Depends(get_verified_user)
+):
try:
- result = await stop_task(task_id)
+ result = await stop_task(request, task_id)
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@app.get("/api/tasks")
-async def list_tasks_endpoint(user=Depends(get_verified_user)):
- return {"tasks": list_tasks()}
+async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
+ return {"tasks": list_tasks(request)}
@app.get("/api/tasks/chat/{chat_id}")
-async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
+async def list_tasks_by_chat_id_endpoint(
+ request: Request, chat_id: str, user=Depends(get_verified_user)
+):
chat = Chats.get_chat_by_id(chat_id)
if chat is None or chat.user_id != user.id:
return {"task_ids": []}
- task_ids = list_task_ids_by_chat_id(chat_id)
+ task_ids = list_task_ids_by_chat_id(request, chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}
@@ -1628,7 +1705,20 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
-app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
+
+
+@app.get("/cache/{path:path}")
+async def serve_cache_file(
+ path: str,
+ user=Depends(get_verified_user),
+):
+    file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
+    # prevent path traversal (trailing separator so "/cache_evil" cannot match "/cache")
+    if not file_path.startswith(os.path.abspath(CACHE_DIR) + os.sep):
+        raise HTTPException(status_code=404, detail="File not found")
+ if not os.path.isfile(file_path):
+ raise HTTPException(status_code=404, detail="File not found")
+ return FileResponse(file_path)
def swagger_ui_html(*args, **kwargs):
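
Prefix checks like the one above are easy to get wrong, which is why the comparison includes a trailing separator. `os.path.commonpath` is an equivalent way to express "the resolved path stays under the base directory":

```python
import os


def is_within(base: str, candidate: str) -> bool:
    """True if candidate resolves to a path inside base (no traversal)."""
    base = os.path.abspath(base)
    candidate = os.path.abspath(candidate)
    return os.path.commonpath([base, candidate]) == base


assert is_within("/app/cache", "/app/cache/image.png")
assert not is_within("/app/cache", "/app/cache/../secrets.txt")  # traversal
assert not is_within("/app/cache", "/app/cache_evil/x")  # bare-startswith pitfall
```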
diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py
index 3222aa27a6..a5dd9467bc 100644
--- a/backend/open_webui/models/users.py
+++ b/backend/open_webui/models/users.py
@@ -95,6 +95,7 @@ class UserRoleUpdateForm(BaseModel):
class UserUpdateForm(BaseModel):
+ role: str
name: str
email: str
profile_image_url: str
diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py
index 0d0ff851b7..fd1f606761 100644
--- a/backend/open_webui/retrieval/loaders/main.py
+++ b/backend/open_webui/retrieval/loaders/main.py
@@ -2,6 +2,7 @@ import requests
import logging
import ftfy
import sys
+import json
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
@@ -76,7 +77,6 @@ known_source_ext = [
"swift",
"vue",
"svelte",
- "msg",
"ex",
"exs",
"erl",
@@ -147,17 +147,32 @@ class DoclingLoader:
)
}
- params = {
- "image_export_mode": "placeholder",
- "table_mode": "accurate",
- }
+ params = {"image_export_mode": "placeholder", "table_mode": "accurate"}
if self.params:
- if self.params.get("do_picture_classification"):
- params["do_picture_classification"] = self.params.get(
- "do_picture_classification"
+ if self.params.get("do_picture_description"):
+ params["do_picture_description"] = self.params.get(
+ "do_picture_description"
)
+ picture_description_mode = self.params.get(
+ "picture_description_mode", ""
+ ).lower()
+
+ if picture_description_mode == "local" and self.params.get(
+ "picture_description_local", {}
+ ):
+ params["picture_description_local"] = self.params.get(
+ "picture_description_local", {}
+ )
+
+ elif picture_description_mode == "api" and self.params.get(
+ "picture_description_api", {}
+ ):
+ params["picture_description_api"] = self.params.get(
+ "picture_description_api", {}
+ )
+
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
@@ -285,17 +300,20 @@ class Loader:
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
+ # Build params for DoclingLoader
+ params = self.kwargs.get("DOCLING_PARAMS", {})
+ if not isinstance(params, dict):
+ try:
+ params = json.loads(params)
+ except json.JSONDecodeError:
+ log.error("Invalid DOCLING_PARAMS format, expected JSON object")
+ params = {}
+
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
- params={
- "ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
- "ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
- "do_picture_classification": self.kwargs.get(
- "DOCLING_DO_PICTURE_DESCRIPTION"
- ),
- },
+ params=params,
)
elif (
self.engine == "document_intelligence"
diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py
index 67641d0509..b00e9d7ce5 100644
--- a/backend/open_webui/retrieval/loaders/mistral.py
+++ b/backend/open_webui/retrieval/loaders/mistral.py
@@ -20,6 +20,14 @@ class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
+
+ Performance Optimizations:
+ - Differentiated timeouts for different operations
+ - Intelligent retry logic with exponential backoff
+ - Memory-efficient file streaming for large files
+ - Connection pooling and keepalive optimization
+ - Semaphore-based concurrency control for batch processing
+ - Enhanced error handling with retryable error classification
"""
BASE_API_URL = "https://api.mistral.ai/v1"
@@ -53,17 +61,40 @@ class MistralLoader:
self.max_retries = max_retries
self.debug = enable_debug_logging
- # Pre-compute file info for performance
+ # PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
+ # This prevents long-running OCR operations from affecting quick operations
+ # and improves user experience by failing fast on operations that should be quick
+ self.upload_timeout = min(
+ timeout, 120
+ ) # Cap upload at 2 minutes - prevents hanging on large files
+ self.url_timeout = (
+ 30 # URL requests should be fast - fail quickly if API is slow
+ )
+ self.ocr_timeout = (
+ timeout # OCR can take the full timeout - this is the heavy operation
+ )
+ self.cleanup_timeout = (
+ 30 # Cleanup should be quick - don't hang on file deletion
+ )
+
+ # PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
+ # This avoids multiple os.path.basename() and os.path.getsize() calls during processing
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
+ # ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = {
"Authorization": f"Bearer {self.api_key}",
- "User-Agent": "OpenWebUI-MistralLoader/2.0",
+ "User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
}
def _debug_log(self, message: str, *args) -> None:
- """Conditional debug logging for performance."""
+ """
+ PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
+
+ Only processes debug messages when debug mode is enabled, avoiding
+ string formatting overhead in production environments.
+ """
if self.debug:
log.debug(message, *args)
@@ -115,53 +146,118 @@ class MistralLoader:
log.error(f"Unexpected error processing response: {e}")
raise
+ def _is_retryable_error(self, error: Exception) -> bool:
+ """
+ ENHANCEMENT: Intelligent error classification for retry logic.
+
+ Determines if an error is retryable based on its type and status code.
+ This prevents wasting time retrying errors that will never succeed
+ (like authentication errors) while ensuring transient errors are retried.
+
+ Retryable errors:
+ - Network connection errors (temporary network issues)
+ - Timeouts (server might be temporarily overloaded)
+ - Server errors (5xx status codes - server-side issues)
+ - Rate limiting (429 status - temporary throttling)
+
+ Non-retryable errors:
+ - Authentication errors (401, 403 - won't fix with retry)
+ - Bad request errors (400 - malformed request)
+ - Not found errors (404 - resource doesn't exist)
+ """
+ if isinstance(error, requests.exceptions.ConnectionError):
+ return True # Network issues are usually temporary
+ if isinstance(error, requests.exceptions.Timeout):
+ return True # Timeouts might resolve on retry
+ if isinstance(error, requests.exceptions.HTTPError):
+ # Only retry on server errors (5xx) or rate limits (429)
+ if hasattr(error, "response") and error.response is not None:
+ status_code = error.response.status_code
+ return status_code >= 500 or status_code == 429
+ return False
+ if isinstance(
+ error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
+ ):
+ return True # Async network/timeout errors are retryable
+ if isinstance(error, aiohttp.ClientResponseError):
+ return error.status >= 500 or error.status == 429
+ return False # All other errors are non-retryable
+
def _retry_request_sync(self, request_func, *args, **kwargs):
- """Synchronous retry logic with exponential backoff."""
+ """
+ ENHANCEMENT: Synchronous retry logic with intelligent error classification.
+
+ Uses exponential backoff with jitter to avoid thundering herd problems.
+ The wait time increases exponentially but is capped at 30 seconds to
+ prevent excessive delays. Only retries errors that are likely to succeed
+ on subsequent attempts.
+ """
for attempt in range(self.max_retries):
try:
return request_func(*args, **kwargs)
- except (requests.exceptions.RequestException, Exception) as e:
- if attempt == self.max_retries - 1:
+ except Exception as e:
+ if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
raise
- wait_time = (2**attempt) + 0.5
+ # PERFORMANCE OPTIMIZATION: Exponential backoff with cap
+ # Prevents overwhelming the server while ensuring reasonable retry delays
+ wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
- f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
+ f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
+ f"Retrying in {wait_time}s..."
)
time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
- """Async retry logic with exponential backoff."""
+ """
+ ENHANCEMENT: Async retry logic with intelligent error classification.
+
+ Async version of retry logic that doesn't block the event loop during
+ wait periods. Uses the same exponential backoff strategy as sync version.
+ """
for attempt in range(self.max_retries):
try:
return await request_func(*args, **kwargs)
- except (aiohttp.ClientError, asyncio.TimeoutError) as e:
- if attempt == self.max_retries - 1:
+ except Exception as e:
+ if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
raise
- wait_time = (2**attempt) + 0.5
+ # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
+ wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
- f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
+ f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
+ f"Retrying in {wait_time}s..."
)
- await asyncio.sleep(wait_time)
+ await asyncio.sleep(wait_time) # Non-blocking wait
def _upload_file(self) -> str:
- """Uploads the file to Mistral for OCR processing (sync version)."""
+ """
+ PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.
+
+ Uploads the file to Mistral for OCR processing (sync version).
+ Uses context manager for file handling to ensure proper resource cleanup.
+ Although streaming is not enabled for this endpoint, the file is opened
+ in a context manager to minimize memory usage duration.
+ """
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
- file_name = os.path.basename(self.file_path)
def upload_request():
+ # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
+ # This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f:
- files = {"file": (file_name, f, "application/pdf")}
+ files = {"file": (self.file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
+ # NOTE: stream=False is required for this endpoint
+ # The Mistral API doesn't support chunked uploads for this endpoint
response = requests.post(
url,
headers=self.headers,
files=files,
data=data,
- timeout=self.timeout,
+ timeout=self.upload_timeout, # Use specialized upload timeout
+ stream=False, # Keep as False for this endpoint
)
return self._handle_response(response)
@@ -209,7 +305,7 @@ class MistralLoader:
url,
data=writer,
headers=self.headers,
- timeout=aiohttp.ClientTimeout(total=self.timeout),
+ timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
) as response:
return await self._handle_response_async(response)
@@ -231,7 +327,7 @@ class MistralLoader:
def url_request():
response = requests.get(
- url, headers=signed_url_headers, params=params, timeout=self.timeout
+ url, headers=signed_url_headers, params=params, timeout=self.url_timeout
)
return self._handle_response(response)
@@ -261,7 +357,7 @@ class MistralLoader:
url,
headers=headers,
params=params,
- timeout=aiohttp.ClientTimeout(total=self.timeout),
+ timeout=aiohttp.ClientTimeout(total=self.url_timeout),
) as response:
return await self._handle_response_async(response)
@@ -294,7 +390,7 @@ class MistralLoader:
def ocr_request():
response = requests.post(
- url, headers=ocr_headers, json=payload, timeout=self.timeout
+ url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
)
return self._handle_response(response)
@@ -336,7 +432,7 @@ class MistralLoader:
url,
json=payload,
headers=headers,
- timeout=aiohttp.ClientTimeout(total=self.timeout),
+ timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
) as response:
ocr_response = await self._handle_response_async(response)
@@ -353,7 +449,9 @@ class MistralLoader:
url = f"{self.BASE_API_URL}/files/{file_id}"
try:
- response = requests.delete(url, headers=self.headers, timeout=30)
+ response = requests.delete(
+ url, headers=self.headers, timeout=self.cleanup_timeout
+ )
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
@@ -372,7 +470,7 @@ class MistralLoader:
url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers,
timeout=aiohttp.ClientTimeout(
- total=30
+ total=self.cleanup_timeout
), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
@@ -388,29 +486,39 @@ class MistralLoader:
async def _get_session(self):
"""Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector(
- limit=10, # Total connection limit
- limit_per_host=5, # Per-host connection limit
- ttl_dns_cache=300, # DNS cache TTL
+ limit=20, # Increased total connection limit for better throughput
+ limit_per_host=10, # Increased per-host limit for API endpoints
+ ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes)
use_dns_cache=True,
- keepalive_timeout=30,
+ keepalive_timeout=60, # Increased keepalive for connection reuse
enable_cleanup_closed=True,
+ force_close=False, # Allow connection reuse
+            resolver=aiohttp.AsyncResolver(),  # Async DNS resolver (requires the aiodns package)
+ )
+
+ timeout = aiohttp.ClientTimeout(
+ total=self.timeout,
+ connect=30, # Connection timeout
+ sock_read=60, # Socket read timeout
)
async with aiohttp.ClientSession(
connector=connector,
- timeout=aiohttp.ClientTimeout(total=self.timeout),
+ timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
+ raise_for_status=False, # We handle status codes manually
) as session:
yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
- """Process OCR results into Document objects with enhanced metadata."""
+ """Process OCR results into Document objects with enhanced metadata and memory efficiency."""
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [
Document(
- page_content="No text content found", metadata={"error": "no_pages"}
+ page_content="No text content found",
+ metadata={"error": "no_pages", "file_name": self.file_name},
)
]
@@ -418,41 +526,44 @@ class MistralLoader:
total_pages = len(pages_data)
skipped_pages = 0
+ # Process pages in a memory-efficient way
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
- if page_content is not None and page_index is not None:
- # Clean up content efficiently
- cleaned_content = (
- page_content.strip()
- if isinstance(page_content, str)
- else str(page_content)
- )
-
- if cleaned_content: # Only add non-empty pages
- documents.append(
- Document(
- page_content=cleaned_content,
- metadata={
- "page": page_index, # 0-based index from API
- "page_label": page_index
- + 1, # 1-based label for convenience
- "total_pages": total_pages,
- "file_name": self.file_name,
- "file_size": self.file_size,
- "processing_engine": "mistral-ocr",
- },
- )
- )
- else:
- skipped_pages += 1
- self._debug_log(f"Skipping empty page {page_index}")
- else:
+ if page_content is None or page_index is None:
skipped_pages += 1
self._debug_log(
- f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
+ f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
)
+ continue
+
+ # Clean up content efficiently with early exit for empty content
+ if isinstance(page_content, str):
+ cleaned_content = page_content.strip()
+ else:
+ cleaned_content = str(page_content).strip()
+
+ if not cleaned_content:
+ skipped_pages += 1
+ self._debug_log(f"Skipping empty page {page_index}")
+ continue
+
+ # Create document with optimized metadata
+ documents.append(
+ Document(
+ page_content=cleaned_content,
+ metadata={
+ "page": page_index, # 0-based index from API
+ "page_label": page_index + 1, # 1-based label for convenience
+ "total_pages": total_pages,
+ "file_name": self.file_name,
+ "file_size": self.file_size,
+ "processing_engine": "mistral-ocr",
+ "content_length": len(cleaned_content),
+ },
+ )
+ )
if skipped_pages > 0:
log.info(
@@ -467,7 +578,11 @@ class MistralLoader:
return [
Document(
page_content="No valid text content found in document",
- metadata={"error": "no_valid_pages", "total_pages": total_pages},
+ metadata={
+ "error": "no_valid_pages",
+ "total_pages": total_pages,
+ "file_name": self.file_name,
+ },
)
]
@@ -585,12 +700,14 @@ class MistralLoader:
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
+ max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]:
"""
- Process multiple files concurrently for maximum performance.
+ Process multiple files concurrently with controlled concurrency.
Args:
loaders: List of MistralLoader instances
+ max_concurrent: Maximum number of concurrent requests
Returns:
List of document lists, one for each loader
@@ -598,11 +715,20 @@ class MistralLoader:
if not loaders:
return []
- log.info(f"Starting concurrent processing of {len(loaders)} files")
+ log.info(
+ f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
+ )
start_time = time.time()
- # Process all files concurrently
- tasks = [loader.load_async() for loader in loaders]
+ # Use semaphore to control concurrency
+ semaphore = asyncio.Semaphore(max_concurrent)
+
+ async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
+ async with semaphore:
+ return await loader.load_async()
+
+ # Process all files with controlled concurrency
+ tasks = [process_with_semaphore(loader) for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
@@ -624,10 +750,18 @@ class MistralLoader:
else:
processed_results.append(result)
+ # MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
+ success_count = sum(
+ 1 for result in results if not isinstance(result, Exception)
+ )
+ failure_count = len(results) - success_count
+
log.info(
- f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
+ f"Batch processing completed in {total_time:.2f}s: "
+ f"{success_count} files succeeded, {failure_count} files failed, "
+ f"produced {total_docs} total documents"
)
return processed_results
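
The semaphore pattern used in `load_multiple_async` generalizes to any batch of coroutines. A self-contained sketch of the same shape:

```python
import asyncio


async def fetch(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for one OCR round trip
    return i


async def run_batch(n: int, max_concurrent: int = 5) -> list[int]:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(i: int) -> int:
        async with semaphore:  # at most max_concurrent tasks run at once
            return await fetch(i)

    return await asyncio.gather(*(bounded(i) for i in range(n)))


print(asyncio.run(run_batch(12)))
```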
diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py
index d908cc8cb5..be5e533588 100644
--- a/backend/open_webui/retrieval/loaders/youtube.py
+++ b/backend/open_webui/retrieval/loaders/youtube.py
@@ -1,4 +1,5 @@
import logging
+from xml.etree.ElementTree import ParseError
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
@@ -93,7 +94,6 @@ class YoutubeLoader:
"http": self.proxy_url,
"https": self.proxy_url,
}
- # Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else:
youtube_proxies = None
@@ -110,11 +110,37 @@ class YoutubeLoader:
for lang in self.language:
try:
transcript = transcript_list.find_transcript([lang])
+ if transcript.is_generated:
+ log.debug(f"Found generated transcript for language '{lang}'")
+ try:
+ transcript = transcript_list.find_manually_created_transcript(
+ [lang]
+ )
+ log.debug(f"Found manual transcript for language '{lang}'")
+ except NoTranscriptFound:
+ log.debug(
+ f"No manual transcript found for language '{lang}', using generated"
+ )
+
log.debug(f"Found transcript for language '{lang}'")
- transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+ try:
+ transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
+ except ParseError:
+ log.debug(f"Empty or invalid transcript for language '{lang}'")
+ continue
+
+ if not transcript_pieces:
+ log.debug(f"Empty transcript for language '{lang}'")
+ continue
+
transcript_text = " ".join(
map(
- lambda transcript_piece: transcript_piece.text.strip(" "),
+ lambda transcript_piece: (
+ transcript_piece.text.strip(" ")
+ if hasattr(transcript_piece, "text")
+ else ""
+ ),
transcript_pieces,
)
)
@@ -131,6 +157,4 @@ class YoutubeLoader:
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
)
- raise NoTranscriptFound(
- f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
- )
+ raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py
index 9f8abf4609..8291332c0f 100644
--- a/backend/open_webui/retrieval/vector/dbs/pinecone.py
+++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py
@@ -3,10 +3,19 @@ import logging
import time # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec
+# Add gRPC support for better performance (Pinecone best practice)
+try:
+ from pinecone.grpc import PineconeGRPC
+
+ GRPC_AVAILABLE = True
+except ImportError:
+ GRPC_AVAILABLE = False
+
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
+import random # for jitter in retry backoff
from open_webui.retrieval.vector.main import (
VectorDBBase,
@@ -47,7 +56,24 @@ class PineconeClient(VectorDBBase):
self.cloud = PINECONE_CLOUD
# Initialize Pinecone client for improved performance
- self.client = Pinecone(api_key=self.api_key)
+ if GRPC_AVAILABLE:
+ # Use gRPC client for better performance (Pinecone recommendation)
+ self.client = PineconeGRPC(
+ api_key=self.api_key,
+ pool_threads=20, # Improved connection pool size
+ timeout=30, # Reasonable timeout for operations
+ )
+ self.using_grpc = True
+ log.info("Using Pinecone gRPC client for optimal performance")
+ else:
+ # Fallback to HTTP client with enhanced connection pooling
+ self.client = Pinecone(
+ api_key=self.api_key,
+ pool_threads=20, # Improved connection pool size
+ timeout=30, # Reasonable timeout for operations
+ )
+ self.using_grpc = False
+ log.info("Using Pinecone HTTP client (gRPC not available)")
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -91,12 +117,53 @@ class PineconeClient(VectorDBBase):
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
- self.index = self.client.Index(self.index_name)
+ self.index = self.client.Index(
+ self.index_name,
+ pool_threads=20, # Enhanced connection pool for index operations
+ )
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
+ def _retry_pinecone_operation(self, operation_func, max_retries=3):
+ """Retry Pinecone operations with exponential backoff for rate limits and network issues."""
+ for attempt in range(max_retries):
+ try:
+ return operation_func()
+ except Exception as e:
+ error_str = str(e).lower()
+ # Check if it's a retryable error (rate limits, network issues, timeouts)
+ is_retryable = any(
+ keyword in error_str
+ for keyword in [
+ "rate limit",
+ "quota",
+ "timeout",
+ "network",
+ "connection",
+ "unavailable",
+ "internal error",
+ "429",
+ "500",
+ "502",
+ "503",
+ "504",
+ ]
+ )
+
+ if not is_retryable or attempt == max_retries - 1:
+ # Don't retry for non-retryable errors or on final attempt
+ raise
+
+ # Exponential backoff with jitter
+ delay = (2**attempt) + random.uniform(0, 1)
+ log.warning(
+ f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
+ f"retrying in {delay:.2f}s: {e}"
+ )
+ time.sleep(delay)
+
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
@@ -223,7 +290,8 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
- f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
+ f"Successfully inserted {len(points)} vectors in parallel batches "
+ f"into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -254,7 +322,8 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
- f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
+ f"Successfully upserted {len(points)} vectors in parallel batches "
+ f"into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -285,7 +354,8 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async insert batch: {result}")
raise result
log.info(
- f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
+ f"Successfully async inserted {len(points)} vectors in batches "
+ f"into '{collection_name_with_prefix}'"
)
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -316,7 +386,8 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async upsert batch: {result}")
raise result
log.info(
- f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
+ f"Successfully async upserted {len(points)} vectors in batches "
+ f"into '{collection_name_with_prefix}'"
)
def search(
@@ -457,10 +528,12 @@ class PineconeClient(VectorDBBase):
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
- f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
+ f"Deleted batch of {len(batch_ids)} vectors by ID "
+ f"from '{collection_name_with_prefix}'"
)
log.info(
- f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
+ f"Successfully deleted {len(ids)} vectors by ID "
+ f"from '{collection_name_with_prefix}'"
)
elif filter:
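
`_retry_pinecone_operation` follows the standard backoff-with-jitter shape. A standalone demo of that shape (without the keyword-based error classification above):

```python
import random
import time


def retry_with_backoff(operation, max_retries: int = 3):
    for attempt in range(max_retries):
        try:
            return operation()
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            delay = (2**attempt) + random.uniform(0, 1)  # 1-2s, 2-3s, ...
            print(f"attempt {attempt + 1} failed ({e}); retrying in {delay:.2f}s")
            time.sleep(delay)


calls = iter([RuntimeError("503 unavailable"), RuntimeError("timeout"), "ok"])


def flaky():
    result = next(calls)
    if isinstance(result, Exception):
        raise result
    return result


print(retry_with_backoff(flaky))  # succeeds on the third attempt
```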
diff --git a/backend/open_webui/retrieval/web/perplexity.py b/backend/open_webui/retrieval/web/perplexity.py
index e5314eb1f7..4e046668fa 100644
--- a/backend/open_webui/retrieval/web/perplexity.py
+++ b/backend/open_webui/retrieval/web/perplexity.py
@@ -1,10 +1,20 @@
import logging
-from typing import Optional, List
+from typing import Optional, Literal
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
+MODELS = Literal[
+ "sonar",
+ "sonar-pro",
+ "sonar-reasoning",
+ "sonar-reasoning-pro",
+ "sonar-deep-research",
+]
+SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
+
+
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -14,6 +24,8 @@ def search_perplexity(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
+ model: MODELS = "sonar",
+ search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
@@ -21,6 +33,9 @@ def search_perplexity(
api_key (str): A Perplexity API key
query (str): The query to search for
count (int): Maximum number of results to return
+ filter_list (Optional[list[str]]): List of domains to filter results
+        model (str): The Perplexity model to use (sonar, sonar-pro, sonar-reasoning, sonar-reasoning-pro, sonar-deep-research)
+ search_context_usage (str): Search context usage level (low, medium, high)
"""
@@ -33,7 +48,7 @@ def search_perplexity(
# Create payload for the API call
payload = {
- "model": "sonar",
+ "model": model,
"messages": [
{
"role": "system",
@@ -43,6 +58,9 @@ def search_perplexity(
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
+ "web_search_options": {
+ "search_context_usage": search_context_usage,
+ },
}
headers = {
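
A call-site sketch for the extended signature; the API key is a placeholder, and `SearchResult` is assumed to expose at least `link` and `title`:

```python
from open_webui.retrieval.web.perplexity import search_perplexity

results = search_perplexity(
    api_key="pplx-<your-key>",
    query="latest FastAPI release notes",
    count=5,
    filter_list=["fastapi.tiangolo.com"],
    model="sonar-pro",
    search_context_usage="high",
)
for r in results:
    print(r.link, r.title)
```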
diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py
index c6d8e41864..52686a5841 100644
--- a/backend/open_webui/routers/images.py
+++ b/backend/open_webui/routers/images.py
@@ -420,7 +420,7 @@ def load_b64_image_data(b64_str):
try:
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
- mime_type = header.split(";")[0]
+            mime_type = header.split(";")[0].removeprefix("data:")
img_data = base64.b64decode(encoded)
else:
mime_type = "image/png"
@@ -428,7 +428,7 @@ def load_b64_image_data(b64_str):
return img_data, mime_type
except Exception as e:
log.exception(f"Error loading image data: {e}")
- return None
+ return None, None
def load_url_image_data(url, headers=None):
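
A quick check of why the fix above uses `removeprefix` rather than `lstrip`: `lstrip` strips any of the given characters, not the literal prefix, so MIME types beginning with those letters get mangled:

```python
header = "data:text/plain;base64"
mime = header.split(";")[0]

print(mime.lstrip("data:"))        # 'ext/plain' -- leading 't' is in the strip set
print(mime.removeprefix("data:"))  # 'text/plain' (Python 3.9+)
```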
diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py
index 5ad5ff051e..94f8325d70 100644
--- a/backend/open_webui/routers/notes.py
+++ b/backend/open_webui/routers/notes.py
@@ -124,9 +124,8 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
- if (
- user.role != "admin"
- and user.id != note.user_id
+ if user.role != "admin" or (
+ user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
):
raise HTTPException(
@@ -159,9 +158,8 @@ async def update_note_by_id(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
- if (
- user.role != "admin"
- and user.id != note.user_id
+ if user.role != "admin" or (
+ user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
@@ -199,9 +197,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
- if (
- user.role != "admin"
- and user.id != note.user_id
+ if user.role != "admin" or (
+ user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index 9c3c393677..7649271fee 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -887,6 +887,88 @@ async def generate_chat_completion(
await session.close()
+async def embeddings(request: Request, form_data: dict, user):
+ """
+ Calls the embeddings endpoint for OpenAI-compatible providers.
+
+ Args:
+ request (Request): The FastAPI request context.
+ form_data (dict): OpenAI-compatible embeddings payload.
+ user (UserModel): The authenticated user.
+
+ Returns:
+ dict: OpenAI-compatible embeddings response.
+ """
+ idx = 0
+ # Prepare payload/body
+ body = json.dumps(form_data)
+ # Find correct backend url/key based on model
+ await get_all_models(request, user=user)
+ model_id = form_data.get("model")
+ models = request.app.state.OPENAI_MODELS
+ if model_id in models:
+ idx = models[model_id]["urlIdx"]
+ url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+ key = request.app.state.config.OPENAI_API_KEYS[idx]
+ r = None
+ session = None
+ streaming = False
+ try:
+ session = aiohttp.ClientSession(trust_env=True)
+ r = await session.request(
+ method="POST",
+ url=f"{url}/embeddings",
+ data=body,
+ headers={
+ "Authorization": f"Bearer {key}",
+ "Content-Type": "application/json",
+ **(
+ {
+ "X-OpenWebUI-User-Name": user.name,
+ "X-OpenWebUI-User-Id": user.id,
+ "X-OpenWebUI-User-Email": user.email,
+ "X-OpenWebUI-User-Role": user.role,
+ }
+ if ENABLE_FORWARD_USER_INFO_HEADERS and user
+ else {}
+ ),
+ },
+ )
+ r.raise_for_status()
+ if "text/event-stream" in r.headers.get("Content-Type", ""):
+ streaming = True
+ return StreamingResponse(
+ r.content,
+ status_code=r.status,
+ headers=dict(r.headers),
+ background=BackgroundTask(
+ cleanup_response, response=r, session=session
+ ),
+ )
+ else:
+ response_data = await r.json()
+ return response_data
+ except Exception as e:
+ log.exception(e)
+ detail = None
+ if r is not None:
+ try:
+ res = await r.json()
+ if "error" in res:
+ detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+ except Exception:
+ detail = f"External: {e}"
+ raise HTTPException(
+ status_code=r.status if r else 500,
+ detail=detail if detail else "Open WebUI: Server Connection Error",
+ )
+ finally:
+ if not streaming and session:
+ if r:
+ r.close()
+ await session.close()
+
+
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"""
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 343b0513c9..2bd73c25e3 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -414,6 +414,9 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+ "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+ "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+ "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -467,6 +470,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
+ "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
+ "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -520,6 +525,8 @@ class WebConfig(BaseModel):
BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
EXA_API_KEY: Optional[str] = None
PERPLEXITY_API_KEY: Optional[str] = None
+ PERPLEXITY_MODEL: Optional[str] = None
+ PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
SOUGOU_API_SID: Optional[str] = None
SOUGOU_API_SK: Optional[str] = None
WEB_LOADER_ENGINE: Optional[str] = None
@@ -571,6 +578,9 @@ class ConfigForm(BaseModel):
DOCLING_OCR_ENGINE: Optional[str] = None
DOCLING_OCR_LANG: Optional[str] = None
DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
+ DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
+ DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
+ DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
MISTRAL_OCR_API_KEY: Optional[str] = None
@@ -744,6 +754,22 @@ async def update_rag_config(
else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
)
+ request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = (
+ form_data.DOCLING_PICTURE_DESCRIPTION_MODE
+ if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None
+ else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
+ )
+ request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = (
+ form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL
+ if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None
+ else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
+ )
+ request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = (
+ form_data.DOCLING_PICTURE_DESCRIPTION_API
+ if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None
+ else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
+ )
+
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
@@ -907,6 +933,10 @@ async def update_rag_config(
)
request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY
request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY
+ request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL
+ request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
+ form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
+ )
request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK
@@ -977,6 +1007,9 @@ async def update_rag_config(
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+ "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+ "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+ "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -1030,6 +1063,8 @@ async def update_rag_config(
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
+ "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
+ "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -1321,9 +1356,14 @@ def process_file(
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
- DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
- DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
- DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+ DOCLING_PARAMS={
+ "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
+ "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
+ "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+ "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+ "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+ "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
+ },
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -1740,19 +1780,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
- elif engine == "exa":
- return search_exa(
- request.app.state.config.EXA_API_KEY,
- query,
- request.app.state.config.WEB_SEARCH_RESULT_COUNT,
- request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
- )
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
+ model=request.app.state.config.PERPLEXITY_MODEL,
+ search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
)
elif engine == "sougou":
if (
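The two new picture-description dicts are passed through to Docling untouched, so their schema belongs to docling-serve rather than Open WebUI; the inner field names below are assumptions for illustration only:

```python
# Illustrative DOCLING_PARAMS; the top-level keys come from the hunk above,
# the picture-description fields inside are assumed, not taken from this diff.
docling_params = {
    "ocr_engine": "easyocr",
    "ocr_lang": "en",
    "do_picture_description": True,
    "picture_description_mode": "local",  # or "api"
    "picture_description_local": {
        "repo_id": "HuggingFaceTB/SmolVLM-256M-Instruct",  # assumed example
        "prompt": "Describe this image in a few sentences.",
    },
    "picture_description_api": None,
}
```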
diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py
index f94346099e..3832c0306b 100644
--- a/backend/open_webui/routers/tasks.py
+++ b/backend/open_webui/routers/tasks.py
@@ -9,6 +9,7 @@ import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
+ follow_up_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
+ "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
+ "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
+ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
+ ENABLE_FOLLOW_UP_GENERATION: bool
ENABLE_TAGS_GENERATION: bool
ENABLE_SEARCH_QUERY_GENERATION: bool
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
@@ -94,6 +100,13 @@ async def update_task_config(
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
+ request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
+ form_data.ENABLE_FOLLOW_UP_GENERATION
+ )
+ request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
+ form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+ )
+
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
@@ -133,6 +146,8 @@ async def update_task_config(
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
+ "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
+ "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -231,6 +246,86 @@ async def generate_title(
)
+@router.post("/follow_up/completions")
+async def generate_follow_ups(
+ request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+
+ if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={"detail": "Follow-up generation is disabled"},
+ )
+
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
+
+ model_id = form_data["model"]
+ if model_id not in models:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+
+ # Check if the user has a custom task model
+ # If the user has a custom task model, use that model
+ task_model_id = get_task_model_id(
+ model_id,
+ request.app.state.config.TASK_MODEL,
+ request.app.state.config.TASK_MODEL_EXTERNAL,
+ models,
+ )
+
+ log.debug(
+ f"generating chat title using model {task_model_id} for user {user.email} "
+ )
+
+ if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
+ template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+ else:
+ template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+
+ content = follow_up_generation_template(
+ template,
+ form_data["messages"],
+ {
+ "name": user.name,
+ "location": user.info.get("location") if user.info else None,
+ },
+ )
+
+ payload = {
+ "model": task_model_id,
+ "messages": [{"role": "user", "content": content}],
+ "stream": False,
+ "metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+ "task": str(TASKS.FOLLOW_UP_GENERATION),
+ "task_body": form_data,
+ "chat_id": form_data.get("chat_id", None),
+ },
+ }
+
+ # Process the payload through the pipeline
+ try:
+ payload = await process_pipeline_inlet_filter(request, payload, user, models)
+ except Exception as e:
+ raise e
+
+ try:
+ return await generate_chat_completion(request, form_data=payload, user=user)
+ except Exception as e:
+ log.error("Exception occurred", exc_info=True)
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": "An internal error has occurred."},
+ )
+
+
@router.post("/tags/completions")
async def generate_chat_tags(
request: Request, form_data: dict, user=Depends(get_verified_user)
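The endpoint leans on an implicit contract: the task model answers with a JSON object containing a `follow_ups` list, and each consumer slices between the first `{` and the last `}` before parsing. A self-contained illustration of that extraction:

```python
import json

# Mirrors the slicing done in process_chat_response and generateFollowUps.
raw = 'Sure! {"follow_ups": ["How does it scale?", "What about auth?"]} Hope that helps.'
payload = raw[raw.find("{") : raw.rfind("}") + 1]
print(json.loads(payload).get("follow_ups", []))
# ['How does it scale?', 'What about auth?']
```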
diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py
index 8702ae50ba..4046dc72d8 100644
--- a/backend/open_webui/routers/users.py
+++ b/backend/open_webui/routers/users.py
@@ -165,22 +165,6 @@ async def update_default_user_permissions(
return request.app.state.config.USER_PERMISSIONS
-############################
-# UpdateUserRole
-############################
-
-
-@router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
- if user.id != form_data.id and form_data.id != Users.get_first_user().id:
- return Users.update_user_role_by_id(form_data.id, form_data.role)
-
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail=ERROR_MESSAGES.ACTION_PROHIBITED,
- )
-
-
############################
# GetUserSettingsBySessionUser
############################
@@ -333,11 +317,22 @@ async def update_user_by_id(
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
- if first_user and user_id == first_user.id and session_user.id != user_id:
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail=ERROR_MESSAGES.ACTION_PROHIBITED,
- )
+ if first_user:
+ if user_id == first_user.id:
+ if session_user.id != user_id:
+ # Another admin is trying to modify the primary admin; block it
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+ )
+
+ if form_data.role != "admin":
+ # If the primary admin is trying to change their own role, prevent it
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+ )
+
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
@@ -365,6 +360,7 @@ async def update_user_by_id(
updated_user = Users.update_user_by_id(
user_id,
{
+ "role": form_data.role,
"name": form_data.name,
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,
diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py
index b64adafb44..0e6768a671 100644
--- a/backend/open_webui/routers/utils.py
+++ b/backend/open_webui/routers/utils.py
@@ -33,7 +33,7 @@ class CodeForm(BaseModel):
@router.post("/code/format")
-async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
+async def format_code(form_data: CodeForm, user=Depends(get_admin_user)):
try:
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}
diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py
index e575e6885c..0923159fb0 100644
--- a/backend/open_webui/tasks.py
+++ b/backend/open_webui/tasks.py
@@ -2,16 +2,90 @@
import asyncio
from typing import Dict
from uuid import uuid4
+import json
+from redis import Redis
+from fastapi import Request
+from typing import List, Optional
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
-def cleanup_task(task_id: str, id=None):
+REDIS_TASKS_KEY = "open-webui:tasks"
+REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
+REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
+
+
+def is_redis(request: Request) -> bool:
+ # Called everywhere a request is available to check Redis
+ return hasattr(request.app.state, "redis") and isinstance(
+ request.app.state.redis, Redis
+ )
+
+
+async def redis_task_command_listener(app):
+ redis: Redis = app.state.redis
+ pubsub = redis.pubsub()
+ await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
+ print("Subscribed to Redis task command channel")
+
+ async for message in pubsub.listen():
+ if message["type"] != "message":
+ continue
+ try:
+ command = json.loads(message["data"])
+ if command.get("action") == "stop":
+ task_id = command.get("task_id")
+ local_task = tasks.get(task_id)
+ if local_task:
+ local_task.cancel()
+ except Exception as e:
+ print(f"Error handling distributed task command: {e}")
+
+
+### ------------------------------
+### REDIS-ENABLED HANDLERS
+### ------------------------------
+
+
+def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+ pipe = redis.pipeline()
+ pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
+ if chat_id:
+ pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+ pipe.execute()
+
+
+def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
+ pipe = redis.pipeline()
+ pipe.hdel(REDIS_TASKS_KEY, task_id)
+ if chat_id:
+ pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
+ if pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute()[-1] == 0:
+ pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
+ pipe.execute()
+
+
+def redis_list_tasks(redis: Redis) -> List[str]:
+ return list(redis.hkeys(REDIS_TASKS_KEY))
+
+
+def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
+ return list(redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
+
+
+def redis_send_command(redis: Redis, command: dict):
+ redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
+
+
+def cleanup_task(request, task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
+ if is_redis(request):
+ redis_cleanup_task(request.app.state.redis, task_id, id)
+
tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary
@@ -21,7 +95,7 @@ def cleanup_task(task_id: str, id=None):
chat_tasks.pop(id, None)
-def create_task(coroutine, id=None):
+def create_task(request, coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -29,7 +103,7 @@ def create_task(coroutine, id=None):
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
- task.add_done_callback(lambda t: cleanup_task(task_id, id))
+ task.add_done_callback(lambda t: cleanup_task(request, task_id, id))
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
@@ -38,34 +112,46 @@ def create_task(coroutine, id=None):
else:
chat_tasks[id] = [task_id]
+ if is_redis(request):
+ redis_save_task(request.app.state.redis, task_id, id)
+
return task_id, task
-def get_task(task_id: str):
- """
- Retrieve a task by its task ID.
- """
- return tasks.get(task_id)
-
-
-def list_tasks():
+def list_tasks(request):
"""
List all currently active task IDs.
"""
+ if is_redis(request):
+ return redis_list_tasks(request.app.state.redis)
return list(tasks.keys())
-def list_task_ids_by_chat_id(id):
+def list_task_ids_by_chat_id(request, id):
"""
List all tasks associated with a specific ID.
"""
+ if is_redis(request):
+ return redis_list_chat_tasks(request.app.state.redis, id)
return chat_tasks.get(id, [])
-async def stop_task(task_id: str):
+async def stop_task(request, task_id: str):
"""
Cancel a running task and remove it from the global task list.
"""
+ if is_redis(request):
+ # PUBSUB: All instances check if they have this task, and stop if so.
+ redis_send_command(
+ request.app.state.redis,
+ {
+ "action": "stop",
+ "task_id": task_id,
+ },
+ )
+ # Optionally check if task_id still in Redis a few moments later for feedback?
+ return {"status": True, "message": f"Stop signal sent for {task_id}"}
+
task = tasks.get(task_id)
if not task:
raise ValueError(f"Task with ID {task_id} not found.")
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 4bd744e3c3..268c910e3e 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -320,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
- "__user__": {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- },
+ "__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -424,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value
if "__user__" in sig.parameters:
- __user__ = {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- }
+ __user__ = user.model_dump() if isinstance(user, UserModel) else {}
try:
if hasattr(function_module, "UserValves"):
diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py
new file mode 100644
index 0000000000..49ce72c3c5
--- /dev/null
+++ b/backend/open_webui/utils/embeddings.py
@@ -0,0 +1,90 @@
+import random
+import logging
+import sys
+
+from fastapi import Request
+from open_webui.models.users import UserModel
+from open_webui.models.models import Models
+from open_webui.utils.models import check_model_access
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
+
+from open_webui.routers.openai import embeddings as openai_embeddings
+from open_webui.routers.ollama import (
+ embeddings as ollama_embeddings,
+ GenerateEmbeddingsForm,
+)
+
+
+from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
+from open_webui.utils.response import convert_embedding_response_ollama_to_openai
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+async def generate_embeddings(
+ request: Request,
+ form_data: dict,
+ user: UserModel,
+ bypass_filter: bool = False,
+):
+ """
+ Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
+
+ Args:
+ request (Request): The FastAPI request context.
+ form_data (dict): The input data sent to the endpoint.
+ user (UserModel): The authenticated user.
+ bypass_filter (bool): If True, disables access filtering (default False).
+
+ Returns:
+ dict: The embeddings response, following OpenAI API compatibility.
+ """
+ if BYPASS_MODEL_ACCESS_CONTROL:
+ bypass_filter = True
+
+ # Attach extra metadata from request.state if present
+ if hasattr(request.state, "metadata"):
+ if "metadata" not in form_data:
+ form_data["metadata"] = request.state.metadata
+ else:
+ form_data["metadata"] = {
+ **form_data["metadata"],
+ **request.state.metadata,
+ }
+
+ # If "direct" flag present, use only that model
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
+
+ model_id = form_data.get("model")
+ if model_id not in models:
+ raise Exception("Model not found")
+ model = models[model_id]
+
+ # Access filtering
+ if not getattr(request.state, "direct", False):
+ if not bypass_filter and user.role == "user":
+ check_model_access(user, model)
+
+ # Ollama backend
+ if model.get("owned_by") == "ollama":
+ ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
+ response = await ollama_embeddings(
+ request=request,
+ form_data=GenerateEmbeddingsForm(**ollama_payload),
+ user=user,
+ )
+ return convert_embedding_response_ollama_to_openai(response)
+
+ # Default: OpenAI or compatible backend
+ return await openai_embeddings(
+ request=request,
+ form_data=form_data,
+ user=user,
+ )
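A hedged usage sketch of the new dispatcher from inside a route handler; the model id is a placeholder, and which branch runs depends on the model's `owned_by` field:

```python
# Ollama-owned models go through the payload/response converters;
# everything else falls through to the OpenAI-compatible path.
result = await generate_embeddings(
    request,
    {"model": "nomic-embed-text", "input": ["hello world"]},  # placeholder id
    user,
)
print(len(result["data"][0]["embedding"]))  # embedding dimensionality
```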
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 7b5659d514..a5a9b8e078 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -32,11 +32,17 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
generate_queries,
generate_title,
+ generate_follow_ups,
generate_image_prompt,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
-from open_webui.routers.images import image_generations, GenerateImageForm
+from open_webui.routers.images import (
+ load_b64_image_data,
+ image_generations,
+ GenerateImageForm,
+ upload_image,
+)
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
@@ -726,12 +732,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_call,
- "__user__": {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- },
+ "__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -1048,6 +1049,59 @@ async def process_chat_response(
)
if tasks and messages:
+ if (
+ TASKS.FOLLOW_UP_GENERATION in tasks
+ and tasks[TASKS.FOLLOW_UP_GENERATION]
+ ):
+ res = await generate_follow_ups(
+ request,
+ {
+ "model": message["model"],
+ "messages": messages,
+ "message_id": metadata["message_id"],
+ "chat_id": metadata["chat_id"],
+ },
+ user,
+ )
+
+ if res and isinstance(res, dict):
+ if len(res.get("choices", [])) == 1:
+ follow_ups_string = (
+ res.get("choices", [])[0]
+ .get("message", {})
+ .get("content", "")
+ )
+ else:
+ follow_ups_string = ""
+
+ follow_ups_string = follow_ups_string[
+ follow_ups_string.find("{") : follow_ups_string.rfind("}")
+ + 1
+ ]
+
+ try:
+ follow_ups = json.loads(follow_ups_string).get(
+ "follow_ups", []
+ )
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ metadata["chat_id"],
+ metadata["message_id"],
+ {
+ "followUps": follow_ups,
+ },
+ )
+
+ await event_emitter(
+ {
+ "type": "chat:message:follow_ups",
+ "data": {
+ "follow_ups": follow_ups,
+ },
+ }
+ )
+ except Exception as e:
+ log.debug(f"Error parsing follow-ups response: {e}")
+
if TASKS.TITLE_GENERATION in tasks:
if tasks[TASKS.TITLE_GENERATION]:
res = await generate_title(
@@ -1273,12 +1327,7 @@ async def process_chat_response(
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_caller,
- "__user__": {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- },
+ "__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -2215,28 +2264,21 @@ async def process_chat_response(
stdoutLines = stdout.split("\n")
for idx, line in enumerate(stdoutLines):
if "data:image/png;base64" in line:
- id = str(uuid4())
-
- # ensure the path exists
- os.makedirs(
- os.path.join(CACHE_DIR, "images"),
- exist_ok=True,
+ image_url = ""
+ # Extract base64 image data from the line
+ image_data, content_type = (
+ load_b64_image_data(line)
)
-
- image_path = os.path.join(
- CACHE_DIR,
- f"images/{id}.png",
- )
-
- with open(image_path, "wb") as f:
- f.write(
- base64.b64decode(
- line.split(",")[1]
- )
+ if image_data is not None:
+ image_url = upload_image(
+ request,
+ image_data,
+ content_type,
+ metadata,
+ user,
)
-
stdoutLines[idx] = (
- f""
+ f""
)
output["stdout"] = "\n".join(stdoutLines)
@@ -2247,30 +2289,22 @@ async def process_chat_response(
resultLines = result.split("\n")
for idx, line in enumerate(resultLines):
if "data:image/png;base64" in line:
- id = str(uuid4())
-
- # ensure the path exists
- os.makedirs(
- os.path.join(CACHE_DIR, "images"),
- exist_ok=True,
+ image_url = ""
+ # Extract base64 image data from the line
+ image_data, content_type = (
+ load_b64_image_data(line)
)
-
- image_path = os.path.join(
- CACHE_DIR,
- f"images/{id}.png",
- )
-
- with open(image_path, "wb") as f:
- f.write(
- base64.b64decode(
- line.split(",")[1]
- )
+ if image_data is not None:
+ image_url = upload_image(
+ request,
+ image_data,
+ content_type,
+ metadata,
+ user,
)
-
resultLines[idx] = (
- f""
+ f""
)
-
output["result"] = "\n".join(resultLines)
except Exception as e:
output = str(e)
@@ -2380,7 +2414,7 @@ async def process_chat_response(
# background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task(
- post_response_handler(response, events), id=metadata["chat_id"]
+ request, post_response_handler(response, events), id=metadata["chat_id"]
)
return {"status": True, "task_id": task_id}
diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py
index de33558596..6c98ed7dfa 100644
--- a/backend/open_webui/utils/oauth.py
+++ b/backend/open_webui/utils/oauth.py
@@ -538,7 +538,7 @@ class OAuthManager:
# Redirect back to the frontend with the JWT token
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
- if redirect_base_url.endswith("/"):
+ if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"):
redirect_base_url = redirect_base_url[:-1]
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py
index 02eb0da22b..8bf705ecf0 100644
--- a/backend/open_webui/utils/payload.py
+++ b/backend/open_webui/utils/payload.py
@@ -329,3 +329,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
ollama_payload["format"] = format
return ollama_payload
+
+
+def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
+ """
+ Convert an embeddings request payload from OpenAI format to Ollama format.
+
+ Args:
+ openai_payload (dict): The original payload designed for OpenAI API usage.
+
+ Returns:
+ dict: A payload compatible with the Ollama API embeddings endpoint.
+ """
+ ollama_payload = {"model": openai_payload.get("model")}
+ input_value = openai_payload.get("input")
+
+ # Ollama expects 'input' as a list, and 'prompt' as a single string.
+ if isinstance(input_value, list):
+ ollama_payload["input"] = input_value
+ ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
+ else:
+ ollama_payload["input"] = [input_value]
+ ollama_payload["prompt"] = str(input_value)
+
+ # Optionally forward other fields if present
+ for optional_key in ("options", "truncate", "keep_alive"):
+ if optional_key in openai_payload:
+ ollama_payload[optional_key] = openai_payload[optional_key]
+
+ return ollama_payload
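The conversion is pure dict-to-dict, so it is easy to check in isolation:

```python
openai_payload = {"model": "nomic-embed-text", "input": ["a", "b"]}
print(convert_embedding_payload_openai_to_ollama(openai_payload))
# {'model': 'nomic-embed-text', 'input': ['a', 'b'], 'prompt': 'a\nb'}
```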
diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py
index e0a53e73d1..85eae55b64 100644
--- a/backend/open_webui/utils/redis.py
+++ b/backend/open_webui/utils/redis.py
@@ -2,6 +2,7 @@ import socketio
import redis
from redis import asyncio as aioredis
from urllib.parse import urlparse
+from typing import Optional
def parse_redis_service_url(redis_url):
@@ -18,7 +19,9 @@ def parse_redis_service_url(redis_url):
}
-def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
+def get_redis_connection(
+ redis_url, redis_sentinels, decode_responses=True
+) -> Optional[redis.Redis]:
if redis_sentinels:
redis_config = parse_redis_service_url(redis_url)
sentinel = redis.sentinel.Sentinel(
@@ -32,9 +35,11 @@ def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
# Get a master connection from Sentinel
return sentinel.master_for(redis_config["service"])
- else:
+ elif redis_url:
# Standard Redis connection
return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+ else:
+ return None
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py
index 8c3f1a58eb..f71087e4ff 100644
--- a/backend/open_webui/utils/response.py
+++ b/backend/open_webui/utils/response.py
@@ -125,3 +125,64 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
yield line
yield "data: [DONE]\n\n"
+
+
+def convert_embedding_response_ollama_to_openai(response) -> dict:
+ """
+ Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.
+
+ Args:
+ response (dict): The response from the Ollama API,
+ e.g. {"embedding": [...], "model": "..."}
+ or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
+
+ Returns:
+ dict: Response adapted to OpenAI's embeddings API format.
+ e.g. {
+ "object": "list",
+ "data": [
+ {"object": "embedding", "embedding": [...], "index": 0},
+ ...
+ ],
+ "model": "...",
+ }
+ """
+ # Ollama batch-style output
+ if isinstance(response, dict) and "embeddings" in response:
+ openai_data = []
+ for i, emb in enumerate(response["embeddings"]):
+ openai_data.append(
+ {
+ "object": "embedding",
+ "embedding": emb.get("embedding"),
+ "index": emb.get("index", i),
+ }
+ )
+ return {
+ "object": "list",
+ "data": openai_data,
+ "model": response.get("model"),
+ }
+ # Ollama single output
+ elif isinstance(response, dict) and "embedding" in response:
+ return {
+ "object": "list",
+ "data": [
+ {
+ "object": "embedding",
+ "embedding": response["embedding"],
+ "index": 0,
+ }
+ ],
+ "model": response.get("model"),
+ }
+ # Already OpenAI-compatible?
+ elif (
+ isinstance(response, dict)
+ and "data" in response
+ and isinstance(response["data"], list)
+ ):
+ return response
+
+ # Fallback: return as is if unrecognized
+ return response
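The single-output branch in isolation:

```python
ollama_response = {"model": "nomic-embed-text", "embedding": [0.1, 0.2, 0.3]}
print(convert_embedding_response_ollama_to_openai(ollama_response))
# {'object': 'list',
#  'data': [{'object': 'embedding', 'embedding': [0.1, 0.2, 0.3], 'index': 0}],
#  'model': 'nomic-embed-text'}
```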
diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py
index 95018eef18..42b44d5167 100644
--- a/backend/open_webui/utils/task.py
+++ b/backend/open_webui/utils/task.py
@@ -207,6 +207,24 @@ def title_generation_template(
return template
+def follow_up_generation_template(
+ template: str, messages: list[dict], user: Optional[dict] = None
+) -> str:
+ prompt = get_last_user_message(messages)
+ template = replace_prompt_variable(template, prompt)
+ template = replace_messages_variable(template, messages)
+
+ template = prompt_template(
+ template,
+ **(
+ {"user_name": user.get("name"), "user_location": user.get("location")}
+ if user
+ else {}
+ ),
+ )
+ return template
+
+
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
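A sketch of a template this helper could render; the `{{MESSAGES:END:N}}` variable syntax is assumed to match what the sibling title/tags templates use:

```python
# Assumed template variables: {{MESSAGES:END:6}} expands to the last six
# messages; the wording must coax the model into the {"follow_ups": [...]}
# shape expected downstream.
template = (
    "Suggest 3 short follow-up questions the user might ask next.\n"
    'Respond as JSON: {"follow_ups": ["...", "...", "..."]}\n\n'
    "Conversation:\n{{MESSAGES:END:6}}"
)
```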
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 9930cd3b68..c714060478 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,5 +1,5 @@
fastapi==0.115.7
-uvicorn[standard]==0.34.0
+uvicorn[standard]==0.34.2
pydantic==2.10.6
python-multipart==0.0.20
@@ -76,13 +76,13 @@ pandas==2.2.3
openpyxl==3.1.5
pyxlsb==1.0.10
xlrd==2.0.1
-validators==0.34.0
+validators==0.35.0
psutil
sentencepiece
soundfile==0.13.1
-azure-ai-documentintelligence==1.0.0
+azure-ai-documentintelligence==1.0.2
-pillow==11.1.0
+pillow==11.2.1
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.4.4
rank-bm25==0.2.2
diff --git a/backend/start.sh b/backend/start.sh
index 84d5ec8958..9e106760c8 100755
--- a/backend/start.sh
+++ b/backend/start.sh
@@ -14,7 +14,11 @@ if [[ "${WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
python -c "import nltk; nltk.download('punkt_tab')"
fi
-KEY_FILE=.webui_secret_key
+if [ -n "${WEBUI_SECRET_KEY_FILE}" ]; then
+ KEY_FILE="${WEBUI_SECRET_KEY_FILE}"
+else
+ KEY_FILE=".webui_secret_key"
+fi
PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"
diff --git a/backend/start_windows.bat b/backend/start_windows.bat
index 8d9aae3ac6..e38fdb2aa6 100644
--- a/backend/start_windows.bat
+++ b/backend/start_windows.bat
@@ -18,6 +18,10 @@ IF /I "%WEB_LOADER_ENGINE%" == "playwright" (
)
SET "KEY_FILE=.webui_secret_key"
+IF NOT "%WEBUI_SECRET_KEY_FILE%" == "" (
+ SET "KEY_FILE=%WEBUI_SECRET_KEY_FILE%"
+)
+
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
diff --git a/package-lock.json b/package-lock.json
index ae602efa0e..fbe35065e1 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -51,7 +51,7 @@
"idb": "^7.1.1",
"js-sha256": "^0.10.1",
"jspdf": "^3.0.0",
- "katex": "^0.16.21",
+ "katex": "^0.16.22",
"kokoro-js": "^1.1.1",
"marked": "^9.1.0",
"mermaid": "^11.6.0",
@@ -7930,9 +7930,9 @@
}
},
"node_modules/katex": {
- "version": "0.16.21",
- "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.21.tgz",
- "integrity": "sha512-XvqR7FgOHtWupfMiigNzmh+MgUVmDGU2kXZm899ZkPfcuoPuFxyHmXsgATDpFZDAXCI8tvinaVcDo8PIIJSo4A==",
+ "version": "0.16.22",
+ "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.22.tgz",
+ "integrity": "sha512-XCHRdUw4lf3SKBaJe4EvgqIuWwkPSo9XoeO8GjQW94Bp7TWv9hNhzZjZ+OH9yf1UmLygb7DIT5GSFQiyt16zYg==",
"funding": [
"https://opencollective.com/katex",
"https://github.com/sponsors/katex"
diff --git a/package.json b/package.json
index b2a71845ed..8737788a92 100644
--- a/package.json
+++ b/package.json
@@ -95,7 +95,7 @@
"idb": "^7.1.1",
"js-sha256": "^0.10.1",
"jspdf": "^3.0.0",
- "katex": "^0.16.21",
+ "katex": "^0.16.22",
"kokoro-js": "^1.1.1",
"marked": "^9.1.0",
"mermaid": "^11.6.0",
diff --git a/scripts/prepare-pyodide.js b/scripts/prepare-pyodide.js
index 70f3cf5c6c..664683a30d 100644
--- a/scripts/prepare-pyodide.js
+++ b/scripts/prepare-pyodide.js
@@ -12,7 +12,8 @@ const packages = [
'sympy',
'tiktoken',
'seaborn',
- 'pytz'
+ 'pytz',
+ 'black'
];
import { loadPyodide } from 'pyodide';
diff --git a/src/app.css b/src/app.css
index ea0bd5fb0a..352a18d213 100644
--- a/src/app.css
+++ b/src/app.css
@@ -44,6 +44,10 @@ code {
font-family: 'InstrumentSerif', sans-serif;
}
+.marked a {
+ @apply underline;
+}
+
math {
margin-top: 1rem;
}
diff --git a/src/lib/apis/auths/index.ts b/src/lib/apis/auths/index.ts
index 169a6c14fc..842edd9c9d 100644
--- a/src/lib/apis/auths/index.ts
+++ b/src/lib/apis/auths/index.ts
@@ -336,7 +336,7 @@ export const userSignOut = async () => {
})
.then(async (res) => {
if (!res.ok) throw await res.json();
- return res;
+ return res.json();
})
.catch((err) => {
console.error(err);
diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts
index 268be397bc..8e4c78aec3 100644
--- a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -612,6 +612,78 @@ export const generateTitle = async (
}
};
+export const generateFollowUps = async (
+ token: string = '',
+ model: string,
+ messages: object[],
+ chat_id?: string
+) => {
+ let error = null;
+
+ const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/follow_up/completions`, {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ Authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ model: model,
+ messages: messages,
+ ...(chat_id && { chat_id: chat_id })
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.error(err);
+ if ('detail' in err) {
+ error = err.detail;
+ }
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ try {
+ // Step 1: Safely extract the response string
+ const response = res?.choices?.[0]?.message?.content ?? '';
+
+ // Step 2: Attempt to fix common JSON format issues like single quotes
+ const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Normalize single/smart quotes and backticks to double quotes so JSON.parse can cope
+
+ // Step 3: Find the relevant JSON block within the response
+ const jsonStartIndex = sanitizedResponse.indexOf('{');
+ const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+ // Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+ if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+ const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+ // Step 5: Parse the JSON block
+ const parsed = JSON.parse(jsonResponse);
+
+ // Step 6: If there's a "follow_ups" key, return the follow_ups array; otherwise, return an empty array
+ if (parsed && parsed.follow_ups) {
+ return Array.isArray(parsed.follow_ups) ? parsed.follow_ups : [];
+ } else {
+ return [];
+ }
+ }
+
+ // If no valid JSON block found, return an empty array
+ return [];
+ } catch (e) {
+ // Catch and safely return empty array on any parsing errors
+ console.error('Failed to parse response: ', e);
+ return [];
+ }
+};
+
export const generateTags = async (
token: string = '',
model: string,
diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts
index f8ab88ff53..391bdca56d 100644
--- a/src/lib/apis/users/index.ts
+++ b/src/lib/apis/users/index.ts
@@ -393,6 +393,7 @@ export const updateUserById = async (token: string, userId: string, user: UserUp
},
body: JSON.stringify({
profile_image_url: user.profile_image_url,
+ role: user.role,
email: user.email,
name: user.name,
password: user.password !== '' ? user.password : undefined
diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte
index 8a708d4d2e..2104d8f939 100644
--- a/src/lib/components/AddConnectionModal.svelte
+++ b/src/lib/components/AddConnectionModal.svelte
@@ -49,6 +49,9 @@
let loading = false;
const verifyOllamaHandler = async () => {
+ // remove trailing slash from url
+ url = url.replace(/\/$/, '');
+
const res = await verifyOllamaConnection(localStorage.token, {
url,
key
@@ -62,6 +65,9 @@
};
const verifyOpenAIHandler = async () => {
+ // remove trailing slash from url
+ url = url.replace(/\/$/, '');
+
const res = await verifyOpenAIConnection(
localStorage.token,
{
diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte
index a5532ae2f2..d223db57ce 100644
--- a/src/lib/components/admin/Evaluations.svelte
+++ b/src/lib/components/admin/Evaluations.svelte
@@ -1,6 +1,8 @@
@@ -37,12 +59,13 @@
class="tabs flex flex-row overflow-x-auto gap-2.5 max-w-full lg:gap-1 lg:flex-col lg:flex-none lg:w-40 dark:text-gray-200 text-sm font-medium text-left scrollbar-none"
>