diff --git a/.gitattributes b/.gitattributes index 526c8a38d4..bf368a4c6c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,49 @@ -*.sh text eol=lf \ No newline at end of file +# TypeScript +*.ts text eol=lf +*.tsx text eol=lf + +# JavaScript +*.js text eol=lf +*.jsx text eol=lf +*.mjs text eol=lf +*.cjs text eol=lf + +# Svelte +*.svelte text eol=lf + +# HTML/CSS +*.html text eol=lf +*.css text eol=lf +*.scss text eol=lf +*.less text eol=lf + +# Config files and JSON +*.json text eol=lf +*.jsonc text eol=lf +*.yml text eol=lf +*.yaml text eol=lf +*.toml text eol=lf + +# Shell scripts +*.sh text eol=lf + +# Markdown & docs +*.md text eol=lf +*.mdx text eol=lf +*.txt text eol=lf + +# Git-related +.gitattributes text eol=lf +.gitignore text eol=lf + +# Prettier and other dotfiles +.prettierrc text eol=lf +.prettierignore text eol=lf +.eslintrc text eol=lf +.eslintignore text eol=lf +.stylelintrc text eol=lf +.editorconfig text eol=lf + +# Misc +*.env text eol=lf +*.lock text eol=lf \ No newline at end of file diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml index e61a69f33a..821ffb7206 100644 --- a/.github/workflows/docker-build.yaml +++ b/.github/workflows/docker-build.yaml @@ -14,16 +14,18 @@ env: jobs: build-main-image: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + runs-on: ${{ matrix.runner }} permissions: contents: read packages: write strategy: fail-fast: false matrix: - platform: - - linux/amd64 - - linux/arm64 + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm steps: # GitHub Packages requires the entire repository name to be in lowercase @@ -111,16 +113,18 @@ jobs: retention-days: 1 build-cuda-image: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + runs-on: ${{ matrix.runner }} permissions: contents: read packages: write strategy: fail-fast: false matrix: - platform: - - linux/amd64 - - linux/arm64 + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm steps: # GitHub Packages requires the entire repository name to be in lowercase @@ -210,17 +214,122 @@ jobs: if-no-files-found: error retention-days: 1 - build-ollama-image: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + build-cuda126-image: + runs-on: ${{ matrix.runner }} permissions: contents: read packages: write strategy: fail-fast: false matrix: - platform: - - linux/amd64 - - linux/arm64 + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm + + steps: + # GitHub Packages requires the entire repository name to be in lowercase + # although the repository owner has a lowercase username, this prevents some people from running actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + + - name: Prepare + run: | + platform=${{ matrix.platform }} + echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ 
env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker images (cuda126 tag) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=git- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126 + flavor: | + latest=${{ github.ref == 'refs/heads/main' }} + suffix=-cuda126,onlatest=true + + - name: Extract metadata for Docker cache + id: cache-meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }} + flavor: | + prefix=cache-cuda126-${{ matrix.platform }}- + latest=false + + - name: Build Docker image (cuda126) + uses: docker/build-push-action@v5 + id: build + with: + context: . + push: true + platforms: ${{ matrix.platform }} + labels: ${{ steps.meta.outputs.labels }} + outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }} + cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max + build-args: | + BUILD_HASH=${{ github.sha }} + USE_CUDA=true + USE_CUDA_VER=cu126 + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: actions/upload-artifact@v4 + with: + name: digests-cuda126-${{ env.PLATFORM_PAIR }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + + build-ollama-image: + runs-on: ${{ matrix.runner }} + permissions: + contents: read + packages: write + strategy: + fail-fast: false + matrix: + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm steps: # GitHub Packages requires the entire repository name to be in lowercase @@ -420,6 +529,62 @@ jobs: run: | docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }} + merge-cuda126-images: + runs-on: ubuntu-latest + needs: [build-cuda126-image] + steps: + # GitHub Packages requires the entire repository name to be in lowercase + # although the repository owner has a lowercase username, this prevents some people from running actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + + - name: Download digests + uses: actions/download-artifact@v4 + with: + pattern: digests-cuda126-* + path: /tmp/digests + merge-multiple: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker images (default latest tag) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=git- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126 + flavor: | + 
latest=${{ github.ref == 'refs/heads/main' }} + suffix=-cuda126,onlatest=true + + - name: Create manifest list and push + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *) + + - name: Inspect image + run: | + docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }} + merge-ollama-images: runs-on: ubuntu-latest needs: [build-ollama-image] diff --git a/.prettierrc b/.prettierrc index a77fddea90..22558729f4 100644 --- a/.prettierrc +++ b/.prettierrc @@ -5,5 +5,6 @@ "printWidth": 100, "plugins": ["prettier-plugin-svelte"], "pluginSearchDirs": ["."], - "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }] + "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }], + "endOfLine": "lf" } diff --git a/README.md b/README.md index 2ad208c3f7..ea1f2acbbe 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ Want to learn more about Open WebUI's features? Check out our [Open WebUI docume - n8n + Warp diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0f49483610..0c7dc3d521 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -901,9 +901,7 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig( #################################### -WEBUI_URL = PersistentConfig( - "WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000") -) +WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "")) ENABLE_SIGNUP = PersistentConfig( @@ -1413,6 +1411,35 @@ Strictly return in JSON format: {{MESSAGES:END:6}} """ + +FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", + "task.follow_up.prompt_template", + os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task: +Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion. +### Guidelines: +- Write all follow-up questions from the user’s point of view, directed to the assistant. +- Make questions concise, clear, and directly related to the discussed topic(s). +- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered. +- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask. +- Use the conversation's primary language; default to English if multilingual. +- Response must be a JSON array of strings, no extra text or formatting. 
+### Output: +JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] } +### Chat History: + +{{MESSAGES:END:6}} +""" + +ENABLE_FOLLOW_UP_GENERATION = PersistentConfig( + "ENABLE_FOLLOW_UP_GENERATION", + "task.follow_up.enable", + os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true", +) + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", @@ -2444,6 +2471,18 @@ PERPLEXITY_API_KEY = PersistentConfig( os.getenv("PERPLEXITY_API_KEY", ""), ) +PERPLEXITY_MODEL = PersistentConfig( + "PERPLEXITY_MODEL", + "rag.web.search.perplexity_model", + os.getenv("PERPLEXITY_MODEL", "sonar"), +) + +PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig( + "PERPLEXITY_SEARCH_CONTEXT_USAGE", + "rag.web.search.perplexity_search_context_usage", + os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"), +) + SOUGOU_API_SID = PersistentConfig( "SOUGOU_API_SID", "rag.web.search.sougou_api_sid", diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 95c54a0d27..59ee6aaacb 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -111,6 +111,7 @@ class TASKS(str, Enum): DEFAULT = lambda task="": f"{task if task else 'generation'}" TITLE_GENERATION = "title_generation" + FOLLOW_UP_GENERATION = "follow_up_generation" TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 20fabb2dc7..6eb5c1bbdb 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -25,6 +25,7 @@ from open_webui.socket.main import ( ) +from open_webui.models.users import UserModel from open_webui.models.functions import Functions from open_webui.models.models import Models @@ -227,12 +228,7 @@ async def generate_function_chat_completion( "__task__": __task__, "__task_body__": __task_body__, "__files__": files, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + "__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__metadata__": metadata, "__request__": request, } diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6bdcf4957a..02d5b0d018 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -37,7 +37,7 @@ from fastapi import ( from fastapi.openapi.docs import get_swagger_ui_html from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.responses import FileResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from starlette_compress import CompressMiddleware @@ -268,6 +268,8 @@ from open_webui.config import ( BRAVE_SEARCH_API_KEY, EXA_API_KEY, PERPLEXITY_API_KEY, + PERPLEXITY_MODEL, + PERPLEXITY_SEARCH_CONTEXT_USAGE, SOUGOU_API_SID, SOUGOU_API_SK, KAGI_SEARCH_API_KEY, @@ -359,10 +361,12 @@ from open_webui.config import ( TASK_MODEL_EXTERNAL, ENABLE_TAGS_GENERATION, ENABLE_TITLE_GENERATION, + ENABLE_FOLLOW_UP_GENERATION, ENABLE_SEARCH_QUERY_GENERATION, ENABLE_RETRIEVAL_QUERY_GENERATION, ENABLE_AUTOCOMPLETE_GENERATION, TITLE_GENERATION_PROMPT_TEMPLATE, + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -411,6 +415,7 @@ from open_webui.utils.chat import ( chat_completed as chat_completed_handler, chat_action as chat_action_handler, ) +from 
open_webui.utils.embeddings import generate_embeddings from open_webui.utils.middleware import process_chat_payload, process_chat_response from open_webui.utils.access_control import has_access @@ -771,6 +776,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY +app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL +app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE app.state.config.SOUGOU_API_SID = SOUGOU_API_SID app.state.config.SOUGOU_API_SK = SOUGOU_API_SK app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL @@ -959,6 +966,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION +app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -966,6 +974,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE ) +app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE +) app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE @@ -1197,6 +1208,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)): return {"data": models} +################################## +# Embeddings +################################## + + +@app.post("/api/embeddings") +async def embeddings( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + """ + OpenAI-compatible embeddings endpoint. + + This handler: + - Performs user/model checks and dispatches to the correct backend. + - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider. + + Args: + request (Request): Request context. + form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]}) + user (UserModel): Authenticated user. + + Returns: + dict: OpenAI-compatible embeddings response. 
+ """ + # Make sure models are loaded in app state + if not request.app.state.MODELS: + await get_all_models(request, user=user) + # Use generic dispatcher in utils.embeddings + return await generate_embeddings(request, form_data, user) + + @app.post("/api/chat/completions") async def chat_completion( request: Request, @@ -1628,7 +1670,20 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") -app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") + + +@app.get("/cache/{path:path}") +async def serve_cache_file( + path: str, + user=Depends(get_verified_user), +): + file_path = os.path.abspath(os.path.join(CACHE_DIR, path)) + # prevent path traversal + if not file_path.startswith(os.path.abspath(CACHE_DIR)): + raise HTTPException(status_code=404, detail="File not found") + if not os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="File not found") + return FileResponse(file_path) def swagger_ui_html(*args, **kwargs): diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 3222aa27a6..a5dd9467bc 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -95,6 +95,7 @@ class UserRoleUpdateForm(BaseModel): class UserUpdateForm(BaseModel): + role: str name: str email: str profile_image_url: str diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index 0d0ff851b7..103a9dc935 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -76,7 +76,6 @@ known_source_ext = [ "swift", "vue", "svelte", - "msg", "ex", "exs", "erl", @@ -147,15 +146,12 @@ class DoclingLoader: ) } - params = { - "image_export_mode": "placeholder", - "table_mode": "accurate", - } + params = {"image_export_mode": "placeholder", "table_mode": "accurate"} if self.params: - if self.params.get("do_picture_classification"): - params["do_picture_classification"] = self.params.get( - "do_picture_classification" + if self.params.get("do_picture_description"): + params["do_picture_description"] = self.params.get( + "do_picture_description" ) if self.params.get("ocr_engine") and self.params.get("ocr_lang"): @@ -292,7 +288,7 @@ class Loader: params={ "ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"), "ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"), - "do_picture_classification": self.kwargs.get( + "do_picture_description": self.kwargs.get( "DOCLING_DO_PICTURE_DESCRIPTION" ), }, diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py index 67641d0509..b00e9d7ce5 100644 --- a/backend/open_webui/retrieval/loaders/mistral.py +++ b/backend/open_webui/retrieval/loaders/mistral.py @@ -20,6 +20,14 @@ class MistralLoader: """ Enhanced Mistral OCR loader with both sync and async support. Loads documents by processing them through the Mistral OCR API. 
+ + Performance Optimizations: + - Differentiated timeouts for different operations + - Intelligent retry logic with exponential backoff + - Memory-efficient file streaming for large files + - Connection pooling and keepalive optimization + - Semaphore-based concurrency control for batch processing + - Enhanced error handling with retryable error classification """ BASE_API_URL = "https://api.mistral.ai/v1" @@ -53,17 +61,40 @@ class MistralLoader: self.max_retries = max_retries self.debug = enable_debug_logging - # Pre-compute file info for performance + # PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations + # This prevents long-running OCR operations from affecting quick operations + # and improves user experience by failing fast on operations that should be quick + self.upload_timeout = min( + timeout, 120 + ) # Cap upload at 2 minutes - prevents hanging on large files + self.url_timeout = ( + 30 # URL requests should be fast - fail quickly if API is slow + ) + self.ocr_timeout = ( + timeout # OCR can take the full timeout - this is the heavy operation + ) + self.cleanup_timeout = ( + 30 # Cleanup should be quick - don't hang on file deletion + ) + + # PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls + # This avoids multiple os.path.basename() and os.path.getsize() calls during processing self.file_name = os.path.basename(file_path) self.file_size = os.path.getsize(file_path) + # ENHANCEMENT: Added User-Agent for better API tracking and debugging self.headers = { "Authorization": f"Bearer {self.api_key}", - "User-Agent": "OpenWebUI-MistralLoader/2.0", + "User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage } def _debug_log(self, message: str, *args) -> None: - """Conditional debug logging for performance.""" + """ + PERFORMANCE OPTIMIZATION: Conditional debug logging for performance. + + Only processes debug messages when debug mode is enabled, avoiding + string formatting overhead in production environments. + """ if self.debug: log.debug(message, *args) @@ -115,53 +146,118 @@ class MistralLoader: log.error(f"Unexpected error processing response: {e}") raise + def _is_retryable_error(self, error: Exception) -> bool: + """ + ENHANCEMENT: Intelligent error classification for retry logic. + + Determines if an error is retryable based on its type and status code. + This prevents wasting time retrying errors that will never succeed + (like authentication errors) while ensuring transient errors are retried. 
+ + Retryable errors: + - Network connection errors (temporary network issues) + - Timeouts (server might be temporarily overloaded) + - Server errors (5xx status codes - server-side issues) + - Rate limiting (429 status - temporary throttling) + + Non-retryable errors: + - Authentication errors (401, 403 - won't fix with retry) + - Bad request errors (400 - malformed request) + - Not found errors (404 - resource doesn't exist) + """ + if isinstance(error, requests.exceptions.ConnectionError): + return True # Network issues are usually temporary + if isinstance(error, requests.exceptions.Timeout): + return True # Timeouts might resolve on retry + if isinstance(error, requests.exceptions.HTTPError): + # Only retry on server errors (5xx) or rate limits (429) + if hasattr(error, "response") and error.response is not None: + status_code = error.response.status_code + return status_code >= 500 or status_code == 429 + return False + if isinstance( + error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError) + ): + return True # Async network/timeout errors are retryable + if isinstance(error, aiohttp.ClientResponseError): + return error.status >= 500 or error.status == 429 + return False # All other errors are non-retryable + def _retry_request_sync(self, request_func, *args, **kwargs): - """Synchronous retry logic with exponential backoff.""" + """ + ENHANCEMENT: Synchronous retry logic with intelligent error classification. + + Uses exponential backoff with jitter to avoid thundering herd problems. + The wait time increases exponentially but is capped at 30 seconds to + prevent excessive delays. Only retries errors that are likely to succeed + on subsequent attempts. + """ for attempt in range(self.max_retries): try: return request_func(*args, **kwargs) - except (requests.exceptions.RequestException, Exception) as e: - if attempt == self.max_retries - 1: + except Exception as e: + if attempt == self.max_retries - 1 or not self._is_retryable_error(e): raise - wait_time = (2**attempt) + 0.5 + # PERFORMANCE OPTIMIZATION: Exponential backoff with cap + # Prevents overwhelming the server while ensuring reasonable retry delays + wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds log.warning( - f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..." + f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " + f"Retrying in {wait_time}s..." ) time.sleep(wait_time) async def _retry_request_async(self, request_func, *args, **kwargs): - """Async retry logic with exponential backoff.""" + """ + ENHANCEMENT: Async retry logic with intelligent error classification. + + Async version of retry logic that doesn't block the event loop during + wait periods. Uses the same exponential backoff strategy as sync version. + """ for attempt in range(self.max_retries): try: return await request_func(*args, **kwargs) - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - if attempt == self.max_retries - 1: + except Exception as e: + if attempt == self.max_retries - 1 or not self._is_retryable_error(e): raise - wait_time = (2**attempt) + 0.5 + # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff + wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds log.warning( - f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..." + f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " + f"Retrying in {wait_time}s..." 
) - await asyncio.sleep(wait_time) + await asyncio.sleep(wait_time) # Non-blocking wait def _upload_file(self) -> str: - """Uploads the file to Mistral for OCR processing (sync version).""" + """ + PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration. + + Uploads the file to Mistral for OCR processing (sync version). + Uses context manager for file handling to ensure proper resource cleanup. + Although streaming is not enabled for this endpoint, the file is opened + in a context manager to minimize memory usage duration. + """ log.info("Uploading file to Mistral API") url = f"{self.BASE_API_URL}/files" - file_name = os.path.basename(self.file_path) def upload_request(): + # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime + # This ensures the file is closed immediately after reading, reducing memory usage with open(self.file_path, "rb") as f: - files = {"file": (file_name, f, "application/pdf")} + files = {"file": (self.file_name, f, "application/pdf")} data = {"purpose": "ocr"} + # NOTE: stream=False is required for this endpoint + # The Mistral API doesn't support chunked uploads for this endpoint response = requests.post( url, headers=self.headers, files=files, data=data, - timeout=self.timeout, + timeout=self.upload_timeout, # Use specialized upload timeout + stream=False, # Keep as False for this endpoint ) return self._handle_response(response) @@ -209,7 +305,7 @@ class MistralLoader: url, data=writer, headers=self.headers, - timeout=aiohttp.ClientTimeout(total=self.timeout), + timeout=aiohttp.ClientTimeout(total=self.upload_timeout), ) as response: return await self._handle_response_async(response) @@ -231,7 +327,7 @@ class MistralLoader: def url_request(): response = requests.get( - url, headers=signed_url_headers, params=params, timeout=self.timeout + url, headers=signed_url_headers, params=params, timeout=self.url_timeout ) return self._handle_response(response) @@ -261,7 +357,7 @@ class MistralLoader: url, headers=headers, params=params, - timeout=aiohttp.ClientTimeout(total=self.timeout), + timeout=aiohttp.ClientTimeout(total=self.url_timeout), ) as response: return await self._handle_response_async(response) @@ -294,7 +390,7 @@ class MistralLoader: def ocr_request(): response = requests.post( - url, headers=ocr_headers, json=payload, timeout=self.timeout + url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout ) return self._handle_response(response) @@ -336,7 +432,7 @@ class MistralLoader: url, json=payload, headers=headers, - timeout=aiohttp.ClientTimeout(total=self.timeout), + timeout=aiohttp.ClientTimeout(total=self.ocr_timeout), ) as response: ocr_response = await self._handle_response_async(response) @@ -353,7 +449,9 @@ class MistralLoader: url = f"{self.BASE_API_URL}/files/{file_id}" try: - response = requests.delete(url, headers=self.headers, timeout=30) + response = requests.delete( + url, headers=self.headers, timeout=self.cleanup_timeout + ) delete_response = self._handle_response(response) log.info(f"File deleted successfully: {delete_response}") except Exception as e: @@ -372,7 +470,7 @@ class MistralLoader: url=f"{self.BASE_API_URL}/files/{file_id}", headers=self.headers, timeout=aiohttp.ClientTimeout( - total=30 + total=self.cleanup_timeout ), # Shorter timeout for cleanup ) as response: return await self._handle_response_async(response) @@ -388,29 +486,39 @@ class MistralLoader: async def _get_session(self): """Context manager for HTTP session with optimized settings.""" connector = aiohttp.TCPConnector( 
- limit=10, # Total connection limit - limit_per_host=5, # Per-host connection limit - ttl_dns_cache=300, # DNS cache TTL + limit=20, # Increased total connection limit for better throughput + limit_per_host=10, # Increased per-host limit for API endpoints + ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes) use_dns_cache=True, - keepalive_timeout=30, + keepalive_timeout=60, # Increased keepalive for connection reuse enable_cleanup_closed=True, + force_close=False, # Allow connection reuse + resolver=aiohttp.AsyncResolver(), # Use async DNS resolver + ) + + timeout = aiohttp.ClientTimeout( + total=self.timeout, + connect=30, # Connection timeout + sock_read=60, # Socket read timeout ) async with aiohttp.ClientSession( connector=connector, - timeout=aiohttp.ClientTimeout(total=self.timeout), + timeout=timeout, headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, + raise_for_status=False, # We handle status codes manually ) as session: yield session def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: - """Process OCR results into Document objects with enhanced metadata.""" + """Process OCR results into Document objects with enhanced metadata and memory efficiency.""" pages_data = ocr_response.get("pages") if not pages_data: log.warning("No pages found in OCR response.") return [ Document( - page_content="No text content found", metadata={"error": "no_pages"} + page_content="No text content found", + metadata={"error": "no_pages", "file_name": self.file_name}, ) ] @@ -418,41 +526,44 @@ class MistralLoader: total_pages = len(pages_data) skipped_pages = 0 + # Process pages in a memory-efficient way for page_data in pages_data: page_content = page_data.get("markdown") page_index = page_data.get("index") # API uses 0-based index - if page_content is not None and page_index is not None: - # Clean up content efficiently - cleaned_content = ( - page_content.strip() - if isinstance(page_content, str) - else str(page_content) - ) - - if cleaned_content: # Only add non-empty pages - documents.append( - Document( - page_content=cleaned_content, - metadata={ - "page": page_index, # 0-based index from API - "page_label": page_index - + 1, # 1-based label for convenience - "total_pages": total_pages, - "file_name": self.file_name, - "file_size": self.file_size, - "processing_engine": "mistral-ocr", - }, - ) - ) - else: - skipped_pages += 1 - self._debug_log(f"Skipping empty page {page_index}") - else: + if page_content is None or page_index is None: skipped_pages += 1 self._debug_log( - f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" + f"Skipping page due to missing 'markdown' or 'index'. 
Data keys: {list(page_data.keys())}" ) + continue + + # Clean up content efficiently with early exit for empty content + if isinstance(page_content, str): + cleaned_content = page_content.strip() + else: + cleaned_content = str(page_content).strip() + + if not cleaned_content: + skipped_pages += 1 + self._debug_log(f"Skipping empty page {page_index}") + continue + + # Create document with optimized metadata + documents.append( + Document( + page_content=cleaned_content, + metadata={ + "page": page_index, # 0-based index from API + "page_label": page_index + 1, # 1-based label for convenience + "total_pages": total_pages, + "file_name": self.file_name, + "file_size": self.file_size, + "processing_engine": "mistral-ocr", + "content_length": len(cleaned_content), + }, + ) + ) if skipped_pages > 0: log.info( @@ -467,7 +578,11 @@ class MistralLoader: return [ Document( page_content="No valid text content found in document", - metadata={"error": "no_valid_pages", "total_pages": total_pages}, + metadata={ + "error": "no_valid_pages", + "total_pages": total_pages, + "file_name": self.file_name, + }, ) ] @@ -585,12 +700,14 @@ class MistralLoader: @staticmethod async def load_multiple_async( loaders: List["MistralLoader"], + max_concurrent: int = 5, # Limit concurrent requests ) -> List[List[Document]]: """ - Process multiple files concurrently for maximum performance. + Process multiple files concurrently with controlled concurrency. Args: loaders: List of MistralLoader instances + max_concurrent: Maximum number of concurrent requests Returns: List of document lists, one for each loader @@ -598,11 +715,20 @@ class MistralLoader: if not loaders: return [] - log.info(f"Starting concurrent processing of {len(loaders)} files") + log.info( + f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent" + ) start_time = time.time() - # Process all files concurrently - tasks = [loader.load_async() for loader in loaders] + # Use semaphore to control concurrency + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_with_semaphore(loader: "MistralLoader") -> List[Document]: + async with semaphore: + return await loader.load_async() + + # Process all files with controlled concurrency + tasks = [process_with_semaphore(loader) for loader in loaders] results = await asyncio.gather(*tasks, return_exceptions=True) # Handle any exceptions in results @@ -624,10 +750,18 @@ class MistralLoader: else: processed_results.append(result) + # MONITORING: Log comprehensive batch processing statistics total_time = time.time() - start_time total_docs = sum(len(docs) for docs in processed_results) + success_count = sum( + 1 for result in results if not isinstance(result, Exception) + ) + failure_count = len(results) - success_count + log.info( - f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents" + f"Batch processing completed in {total_time:.2f}s: " + f"{success_count} files succeeded, {failure_count} files failed, " + f"produced {total_docs} total documents" ) return processed_results diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index 9f8abf4609..8291332c0f 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -3,10 +3,19 @@ import logging import time # for measuring elapsed time from pinecone import Pinecone, ServerlessSpec +# Add gRPC support for better performance (Pinecone best practice) +try: 
+ from pinecone.grpc import PineconeGRPC + + GRPC_AVAILABLE = True +except ImportError: + GRPC_AVAILABLE = False + import asyncio # for async upserts import functools # for partial binding in async tasks import concurrent.futures # for parallel batch upserts +import random # for jitter in retry backoff from open_webui.retrieval.vector.main import ( VectorDBBase, @@ -47,7 +56,24 @@ class PineconeClient(VectorDBBase): self.cloud = PINECONE_CLOUD # Initialize Pinecone client for improved performance - self.client = Pinecone(api_key=self.api_key) + if GRPC_AVAILABLE: + # Use gRPC client for better performance (Pinecone recommendation) + self.client = PineconeGRPC( + api_key=self.api_key, + pool_threads=20, # Improved connection pool size + timeout=30, # Reasonable timeout for operations + ) + self.using_grpc = True + log.info("Using Pinecone gRPC client for optimal performance") + else: + # Fallback to HTTP client with enhanced connection pooling + self.client = Pinecone( + api_key=self.api_key, + pool_threads=20, # Improved connection pool size + timeout=30, # Reasonable timeout for operations + ) + self.using_grpc = False + log.info("Using Pinecone HTTP client (gRPC not available)") # Persistent executor for batch operations self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) @@ -91,12 +117,53 @@ class PineconeClient(VectorDBBase): log.info(f"Using existing Pinecone index '{self.index_name}'") # Connect to the index - self.index = self.client.Index(self.index_name) + self.index = self.client.Index( + self.index_name, + pool_threads=20, # Enhanced connection pool for index operations + ) except Exception as e: log.error(f"Failed to initialize Pinecone index: {e}") raise RuntimeError(f"Failed to initialize Pinecone index: {e}") + def _retry_pinecone_operation(self, operation_func, max_retries=3): + """Retry Pinecone operations with exponential backoff for rate limits and network issues.""" + for attempt in range(max_retries): + try: + return operation_func() + except Exception as e: + error_str = str(e).lower() + # Check if it's a retryable error (rate limits, network issues, timeouts) + is_retryable = any( + keyword in error_str + for keyword in [ + "rate limit", + "quota", + "timeout", + "network", + "connection", + "unavailable", + "internal error", + "429", + "500", + "502", + "503", + "504", + ] + ) + + if not is_retryable or attempt == max_retries - 1: + # Don't retry for non-retryable errors or on final attempt + raise + + # Exponential backoff with jitter + delay = (2**attempt) + random.uniform(0, 1) + log.warning( + f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), " + f"retrying in {delay:.2f}s: {e}" + ) + time.sleep(delay) + def _create_points( self, items: List[VectorItem], collection_name_with_prefix: str ) -> List[Dict[str, Any]]: @@ -223,7 +290,8 @@ class PineconeClient(VectorDBBase): elapsed = time.time() - start_time log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds") log.info( - f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" + f"Successfully inserted {len(points)} vectors in parallel batches " + f"into '{collection_name_with_prefix}'" ) def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -254,7 +322,8 @@ class PineconeClient(VectorDBBase): elapsed = time.time() - start_time log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds") log.info( - f"Successfully upserted {len(points)} vectors in parallel batches into 
'{collection_name_with_prefix}'" + f"Successfully upserted {len(points)} vectors in parallel batches " + f"into '{collection_name_with_prefix}'" ) async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None: @@ -285,7 +354,8 @@ class PineconeClient(VectorDBBase): log.error(f"Error in async insert batch: {result}") raise result log.info( - f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" + f"Successfully async inserted {len(points)} vectors in batches " + f"into '{collection_name_with_prefix}'" ) async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None: @@ -316,7 +386,8 @@ class PineconeClient(VectorDBBase): log.error(f"Error in async upsert batch: {result}") raise result log.info( - f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" + f"Successfully async upserted {len(points)} vectors in batches " + f"into '{collection_name_with_prefix}'" ) def search( @@ -457,10 +528,12 @@ class PineconeClient(VectorDBBase): # This is a limitation of Pinecone - be careful with ID uniqueness self.index.delete(ids=batch_ids) log.debug( - f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'" + f"Deleted batch of {len(batch_ids)} vectors by ID " + f"from '{collection_name_with_prefix}'" ) log.info( - f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'" + f"Successfully deleted {len(ids)} vectors by ID " + f"from '{collection_name_with_prefix}'" ) elif filter: diff --git a/backend/open_webui/retrieval/web/perplexity.py b/backend/open_webui/retrieval/web/perplexity.py index e5314eb1f7..4e046668fa 100644 --- a/backend/open_webui/retrieval/web/perplexity.py +++ b/backend/open_webui/retrieval/web/perplexity.py @@ -1,10 +1,20 @@ import logging -from typing import Optional, List +from typing import Optional, Literal import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS +MODELS = Literal[ + "sonar", + "sonar-pro", + "sonar-reasoning", + "sonar-reasoning-pro", + "sonar-deep-research", +] +SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"] + + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -14,6 +24,8 @@ def search_perplexity( query: str, count: int, filter_list: Optional[list[str]] = None, + model: MODELS = "sonar", + search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium", ) -> list[SearchResult]: """Search using Perplexity API and return the results as a list of SearchResult objects. 
@@ -21,6 +33,9 @@ def search_perplexity( api_key (str): A Perplexity API key query (str): The query to search for count (int): Maximum number of results to return + filter_list (Optional[list[str]]): List of domains to filter results + model (str): The Perplexity model to use (sonar, sonar-pro) + search_context_usage (str): Search context usage level (low, medium, high) """ @@ -33,7 +48,7 @@ def search_perplexity( # Create payload for the API call payload = { - "model": "sonar", + "model": model, "messages": [ { "role": "system", @@ -43,6 +58,9 @@ def search_perplexity( ], "temperature": 0.2, # Lower temperature for more factual responses "stream": False, + "web_search_options": { + "search_context_usage": search_context_usage, + }, } headers = { diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 5ad5ff051e..94f8325d70 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -124,9 +124,8 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if ( - user.role != "admin" - and user.id != note.user_id + if user.role != "admin" or ( + user.id != note.user_id and not has_access(user.id, type="read", access_control=note.access_control) ): raise HTTPException( @@ -159,9 +158,8 @@ async def update_note_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if ( - user.role != "admin" - and user.id != note.user_id + if user.role != "admin" or ( + user.id != note.user_id and not has_access(user.id, type="write", access_control=note.access_control) ): raise HTTPException( @@ -199,9 +197,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if ( - user.role != "admin" - and user.id != note.user_id + if user.role != "admin" or ( + user.id != note.user_id and not has_access(user.id, type="write", access_control=note.access_control) ): raise HTTPException( diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 9c3c393677..7649271fee 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -887,6 +887,88 @@ async def generate_chat_completion( await session.close() +async def embeddings(request: Request, form_data: dict, user): + """ + Calls the embeddings endpoint for OpenAI-compatible providers. + + Args: + request (Request): The FastAPI request context. + form_data (dict): OpenAI-compatible embeddings payload. + user (UserModel): The authenticated user. + + Returns: + dict: OpenAI-compatible embeddings response. 
+ """ + idx = 0 + # Prepare payload/body + body = json.dumps(form_data) + # Find correct backend url/key based on model + await get_all_models(request, user=user) + model_id = form_data.get("model") + models = request.app.state.OPENAI_MODELS + if model_id in models: + idx = models[model_id]["urlIdx"] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + r = None + session = None + streaming = False + try: + session = aiohttp.ClientSession(trust_env=True) + r = await session.request( + method="POST", + url=f"{url}/embeddings", + data=body, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, + ) + r.raise_for_status() + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + response_data = await r.json() + return response_data + except Exception as e: + log.exception(e) + detail = None + if r is not None: + try: + res = await r.json() + if "error" in res: + detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + except Exception: + detail = f"External: {e}" + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) + finally: + if not streaming and session: + if r: + r.close() + await session.close() + + @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): """ diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 343b0513c9..22b264bfad 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -467,6 +467,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "EXA_API_KEY": request.app.state.config.EXA_API_KEY, "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, + "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, + "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, @@ -520,6 +522,8 @@ class WebConfig(BaseModel): BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None EXA_API_KEY: Optional[str] = None PERPLEXITY_API_KEY: Optional[str] = None + PERPLEXITY_MODEL: Optional[str] = None + PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None SOUGOU_API_SID: Optional[str] = None SOUGOU_API_SK: Optional[str] = None WEB_LOADER_ENGINE: Optional[str] = None @@ -907,6 +911,10 @@ async def update_rag_config( ) request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY + request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL + request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = ( + 
form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE + ) request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK @@ -1030,6 +1038,8 @@ async def update_rag_config( "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "EXA_API_KEY": request.app.state.config.EXA_API_KEY, "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, + "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, + "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, @@ -1740,19 +1750,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) - elif engine == "exa": - return search_exa( - request.app.state.config.EXA_API_KEY, - query, - request.app.state.config.WEB_SEARCH_RESULT_COUNT, - request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, - ) elif engine == "perplexity": return search_perplexity( request.app.state.config.PERPLEXITY_API_KEY, query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + model=request.app.state.config.PERPLEXITY_MODEL, + search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, ) elif engine == "sougou": if ( diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index f94346099e..3832c0306b 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -9,6 +9,7 @@ import re from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( title_generation_template, + follow_up_generation_template, query_generation_template, image_prompt_generation_template, autocomplete_generation_template, @@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id from open_webui.config import ( DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, @@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel): ENABLE_AUTOCOMPLETE_GENERATION: bool AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str + ENABLE_FOLLOW_UP_GENERATION: bool ENABLE_TAGS_GENERATION: bool ENABLE_SEARCH_QUERY_GENERATION: bool 
ENABLE_RETRIEVAL_QUERY_GENERATION: bool @@ -94,6 +100,13 @@ async def update_task_config( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) + request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = ( + form_data.ENABLE_FOLLOW_UP_GENERATION + ) + request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( + form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE ) @@ -133,6 +146,8 @@ async def update_task_config( "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, @@ -231,6 +246,86 @@ async def generate_title( ) +@router.post("/follow_up/completions") +async def generate_follow_ups( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Follow-up generation is disabled"}, + ) + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + + content = follow_up_generation_template( + template, + form_data["messages"], + { + "name": user.name, + "location": user.info.get("location") if user.info else None, + }, + ) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), + "task": str(TASKS.FOLLOW_UP_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) 
+
+
 @router.post("/tags/completions")
 async def generate_chat_tags(
     request: Request, form_data: dict, user=Depends(get_verified_user)
diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py
index 8702ae50ba..4046dc72d8 100644
--- a/backend/open_webui/routers/users.py
+++ b/backend/open_webui/routers/users.py
@@ -165,22 +165,6 @@ async def update_default_user_permissions(
     return request.app.state.config.USER_PERMISSIONS


-############################
-# UpdateUserRole
-############################
-
-
-@router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
-    if user.id != form_data.id and form_data.id != Users.get_first_user().id:
-        return Users.update_user_role_by_id(form_data.id, form_data.role)
-
-    raise HTTPException(
-        status_code=status.HTTP_403_FORBIDDEN,
-        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-    )
-
-
 ############################
 # GetUserSettingsBySessionUser
 ############################
@@ -333,11 +317,22 @@ async def update_user_by_id(
     # Prevent modification of the primary admin user by other admins
     try:
         first_user = Users.get_first_user()
-        if first_user and user_id == first_user.id and session_user.id != user_id:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-            )
+        if first_user:
+            if user_id == first_user.id:
+                if session_user.id != user_id:
+                    # Another admin is trying to modify the primary admin account; prevent it
+                    raise HTTPException(
+                        status_code=status.HTTP_403_FORBIDDEN,
+                        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+                    )
+
+                if form_data.role != "admin":
+                    # If the primary admin is trying to change their own role, prevent it
+                    raise HTTPException(
+                        status_code=status.HTTP_403_FORBIDDEN,
+                        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+                    )
+
     except Exception as e:
         log.error(f"Error checking primary admin status: {e}")
         raise HTTPException(
@@ -365,6 +360,7 @@ async def update_user_by_id(
     updated_user = Users.update_user_by_id(
         user_id,
         {
+            "role": form_data.role,
             "name": form_data.name,
             "email": form_data.email.lower(),
             "profile_image_url": form_data.profile_image_url,
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py
index 4bd744e3c3..268c910e3e 100644
--- a/backend/open_webui/utils/chat.py
+++ b/backend/open_webui/utils/chat.py
@@ -320,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     extra_params = {
         "__event_emitter__": get_event_emitter(metadata),
         "__event_call__": get_event_call(metadata),
-        "__user__": {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-        },
+        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
         "__metadata__": metadata,
         "__request__": request,
         "__model__": model,
@@ -424,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
                 params[key] = value

         if "__user__" in sig.parameters:
-            __user__ = {
-                "id": user.id,
-                "email": user.email,
-                "name": user.name,
-                "role": user.role,
-            }
+            __user__ = user.model_dump() if isinstance(user, UserModel) else {}

             try:
                 if hasattr(function_module, "UserValves"):
diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py
new file mode 100644
index 0000000000..49ce72c3c5
--- /dev/null
+++ b/backend/open_webui/utils/embeddings.py
@@ -0,0 +1,90 @@
+import random
+import logging
+import sys
+
+from fastapi import Request
+from 
open_webui.models.users import UserModel +from open_webui.models.models import Models +from open_webui.utils.models import check_model_access +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL + +from open_webui.routers.openai import embeddings as openai_embeddings +from open_webui.routers.ollama import ( + embeddings as ollama_embeddings, + GenerateEmbeddingsForm, +) + + +from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama +from open_webui.utils.response import convert_embedding_response_ollama_to_openai + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + + +async def generate_embeddings( + request: Request, + form_data: dict, + user: UserModel, + bypass_filter: bool = False, +): + """ + Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama). + + Args: + request (Request): The FastAPI request context. + form_data (dict): The input data sent to the endpoint. + user (UserModel): The authenticated user. + bypass_filter (bool): If True, disables access filtering (default False). + + Returns: + dict: The embeddings response, following OpenAI API compatibility. + """ + if BYPASS_MODEL_ACCESS_CONTROL: + bypass_filter = True + + # Attach extra metadata from request.state if present + if hasattr(request.state, "metadata"): + if "metadata" not in form_data: + form_data["metadata"] = request.state.metadata + else: + form_data["metadata"] = { + **form_data["metadata"], + **request.state.metadata, + } + + # If "direct" flag present, use only that model + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + + model_id = form_data.get("model") + if model_id not in models: + raise Exception("Model not found") + model = models[model_id] + + # Access filtering + if not getattr(request.state, "direct", False): + if not bypass_filter and user.role == "user": + check_model_access(user, model) + + # Ollama backend + if model.get("owned_by") == "ollama": + ollama_payload = convert_embedding_payload_openai_to_ollama(form_data) + response = await ollama_embeddings( + request=request, + form_data=GenerateEmbeddingsForm(**ollama_payload), + user=user, + ) + return convert_embedding_response_ollama_to_openai(response) + + # Default: OpenAI or compatible backend + return await openai_embeddings( + request=request, + form_data=form_data, + user=user, + ) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 7b5659d514..6510d7c99b 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -32,6 +32,7 @@ from open_webui.socket.main import ( from open_webui.routers.tasks import ( generate_queries, generate_title, + generate_follow_ups, generate_image_prompt, generate_chat_tags, ) @@ -726,12 +727,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_call, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + "__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__metadata__": metadata, "__request__": request, "__model__": model, @@ -1048,6 +1044,59 @@ async def process_chat_response( ) if tasks and messages: + if ( + TASKS.FOLLOW_UP_GENERATION in tasks 
+ and tasks[TASKS.FOLLOW_UP_GENERATION] + ): + res = await generate_follow_ups( + request, + { + "model": message["model"], + "messages": messages, + "message_id": metadata["message_id"], + "chat_id": metadata["chat_id"], + }, + user, + ) + + if res and isinstance(res, dict): + if len(res.get("choices", [])) == 1: + follow_ups_string = ( + res.get("choices", [])[0] + .get("message", {}) + .get("content", "") + ) + else: + follow_ups_string = "" + + follow_ups_string = follow_ups_string[ + follow_ups_string.find("{") : follow_ups_string.rfind("}") + + 1 + ] + + try: + follow_ups = json.loads(follow_ups_string).get( + "follow_ups", [] + ) + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "followUps": follow_ups, + }, + ) + + await event_emitter( + { + "type": "chat:message:follow_ups", + "data": { + "follow_ups": follow_ups, + }, + } + ) + except Exception as e: + pass + if TASKS.TITLE_GENERATION in tasks: if tasks[TASKS.TITLE_GENERATION]: res = await generate_title( @@ -1273,12 +1322,7 @@ async def process_chat_response( extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_caller, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + "__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__metadata__": metadata, "__request__": request, "__model__": model, diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index de33558596..6c98ed7dfa 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -538,7 +538,7 @@ class OAuthManager: # Redirect back to the frontend with the JWT token redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url - if redirect_base_url.endswith("/"): + if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"): redirect_base_url = redirect_base_url[:-1] redirect_url = f"{redirect_base_url}/auth#token={jwt_token}" diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 02eb0da22b..8bf705ecf0 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -329,3 +329,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: ollama_payload["format"] = format return ollama_payload + + +def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict: + """ + Convert an embeddings request payload from OpenAI format to Ollama format. + + Args: + openai_payload (dict): The original payload designed for OpenAI API usage. + + Returns: + dict: A payload compatible with the Ollama API embeddings endpoint. + """ + ollama_payload = {"model": openai_payload.get("model")} + input_value = openai_payload.get("input") + + # Ollama expects 'input' as a list, and 'prompt' as a single string. 
+ if isinstance(input_value, list): + ollama_payload["input"] = input_value + ollama_payload["prompt"] = "\n".join(str(x) for x in input_value) + else: + ollama_payload["input"] = [input_value] + ollama_payload["prompt"] = str(input_value) + + # Optionally forward other fields if present + for optional_key in ("options", "truncate", "keep_alive"): + if optional_key in openai_payload: + ollama_payload[optional_key] = openai_payload[optional_key] + + return ollama_payload diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index 8c3f1a58eb..f71087e4ff 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -125,3 +125,64 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response) yield line yield "data: [DONE]\n\n" + + +def convert_embedding_response_ollama_to_openai(response) -> dict: + """ + Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format. + + Args: + response (dict): The response from the Ollama API, + e.g. {"embedding": [...], "model": "..."} + or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."} + + Returns: + dict: Response adapted to OpenAI's embeddings API format. + e.g. { + "object": "list", + "data": [ + {"object": "embedding", "embedding": [...], "index": 0}, + ... + ], + "model": "...", + } + """ + # Ollama batch-style output + if isinstance(response, dict) and "embeddings" in response: + openai_data = [] + for i, emb in enumerate(response["embeddings"]): + openai_data.append( + { + "object": "embedding", + "embedding": emb.get("embedding"), + "index": emb.get("index", i), + } + ) + return { + "object": "list", + "data": openai_data, + "model": response.get("model"), + } + # Ollama single output + elif isinstance(response, dict) and "embedding" in response: + return { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": response["embedding"], + "index": 0, + } + ], + "model": response.get("model"), + } + # Already OpenAI-compatible? 
+ elif ( + isinstance(response, dict) + and "data" in response + and isinstance(response["data"], list) + ): + return response + + # Fallback: return as is if unrecognized + return response diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 95018eef18..42b44d5167 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -207,6 +207,24 @@ def title_generation_template( return template +def follow_up_generation_template( + template: str, messages: list[dict], user: Optional[dict] = None +) -> str: + prompt = get_last_user_message(messages) + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + return template + + def tags_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: diff --git a/backend/requirements.txt b/backend/requirements.txt index 9930cd3b68..c714060478 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,5 +1,5 @@ fastapi==0.115.7 -uvicorn[standard]==0.34.0 +uvicorn[standard]==0.34.2 pydantic==2.10.6 python-multipart==0.0.20 @@ -76,13 +76,13 @@ pandas==2.2.3 openpyxl==3.1.5 pyxlsb==1.0.10 xlrd==2.0.1 -validators==0.34.0 +validators==0.35.0 psutil sentencepiece soundfile==0.13.1 -azure-ai-documentintelligence==1.0.0 +azure-ai-documentintelligence==1.0.2 -pillow==11.1.0 +pillow==11.2.1 opencv-python-headless==4.11.0.86 rapidocr-onnxruntime==1.4.4 rank-bm25==0.2.2 diff --git a/package-lock.json b/package-lock.json index ae602efa0e..fbe35065e1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -51,7 +51,7 @@ "idb": "^7.1.1", "js-sha256": "^0.10.1", "jspdf": "^3.0.0", - "katex": "^0.16.21", + "katex": "^0.16.22", "kokoro-js": "^1.1.1", "marked": "^9.1.0", "mermaid": "^11.6.0", @@ -7930,9 +7930,9 @@ } }, "node_modules/katex": { - "version": "0.16.21", - "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.21.tgz", - "integrity": "sha512-XvqR7FgOHtWupfMiigNzmh+MgUVmDGU2kXZm899ZkPfcuoPuFxyHmXsgATDpFZDAXCI8tvinaVcDo8PIIJSo4A==", + "version": "0.16.22", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.22.tgz", + "integrity": "sha512-XCHRdUw4lf3SKBaJe4EvgqIuWwkPSo9XoeO8GjQW94Bp7TWv9hNhzZjZ+OH9yf1UmLygb7DIT5GSFQiyt16zYg==", "funding": [ "https://opencollective.com/katex", "https://github.com/sponsors/katex" diff --git a/package.json b/package.json index b2a71845ed..8737788a92 100644 --- a/package.json +++ b/package.json @@ -95,7 +95,7 @@ "idb": "^7.1.1", "js-sha256": "^0.10.1", "jspdf": "^3.0.0", - "katex": "^0.16.21", + "katex": "^0.16.22", "kokoro-js": "^1.1.1", "marked": "^9.1.0", "mermaid": "^11.6.0", diff --git a/src/app.css b/src/app.css index ea0bd5fb0a..352a18d213 100644 --- a/src/app.css +++ b/src/app.css @@ -44,6 +44,10 @@ code { font-family: 'InstrumentSerif', sans-serif; } +.marked a { + @apply underline; +} + math { margin-top: 1rem; } diff --git a/src/lib/apis/auths/index.ts b/src/lib/apis/auths/index.ts index 169a6c14fc..842edd9c9d 100644 --- a/src/lib/apis/auths/index.ts +++ b/src/lib/apis/auths/index.ts @@ -336,7 +336,7 @@ export const userSignOut = async () => { }) .then(async (res) => { if (!res.ok) throw await res.json(); - return res; + return res.json(); }) .catch((err) => { console.error(err); diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 268be397bc..8e4c78aec3 100644 --- 
a/src/lib/apis/index.ts
+++ b/src/lib/apis/index.ts
@@ -612,6 +612,78 @@ export const generateTitle = async (
 	}
 };
 
+export const generateFollowUps = async (
+	token: string = '',
+	model: string,
+	messages: object[],
+	chat_id?: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/follow_up/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			messages: messages,
+			...(chat_id && { chat_id: chat_id })
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.error(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	try {
+		// Step 1: Safely extract the response string
+		const response = res?.choices[0]?.message?.content ?? '';
+
+		// Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON
+
+		// Step 3: Find the relevant JSON block within the response
+		const jsonStartIndex = sanitizedResponse.indexOf('{');
+		const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+		// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+		if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+			const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+			// Step 5: Parse the JSON block
+			const parsed = JSON.parse(jsonResponse);
+
+			// Step 6: If there's a "follow_ups" key, return the follow_ups array; otherwise, return an empty array
+			if (parsed && parsed.follow_ups) {
+				return Array.isArray(parsed.follow_ups) ? parsed.follow_ups : [];
+			} else {
+				return [];
+			}
+		}
+
+		// If no valid JSON block found, return an empty array
+		return [];
+	} catch (e) {
+		// Catch and safely return empty array on any parsing errors
+		console.error('Failed to parse response: ', e);
+		return [];
+	}
+};
+
 export const generateTags = async (
 	token: string = '',
 	model: string,
diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts
index f8ab88ff53..391bdca56d 100644
--- a/src/lib/apis/users/index.ts
+++ b/src/lib/apis/users/index.ts
@@ -393,6 +393,7 @@ export const updateUserById = async (token: string, userId: string, user: UserUp
 		},
 		body: JSON.stringify({
 			profile_image_url: user.profile_image_url,
+			role: user.role,
 			email: user.email,
 			name: user.name,
 			password: user.password !== '' ? 
user.password : undefined diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index 8a708d4d2e..2104d8f939 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -49,6 +49,9 @@ let loading = false; const verifyOllamaHandler = async () => { + // remove trailing slash from url + url = url.replace(/\/$/, ''); + const res = await verifyOllamaConnection(localStorage.token, { url, key @@ -62,6 +65,9 @@ }; const verifyOpenAIHandler = async () => { + // remove trailing slash from url + url = url.replace(/\/$/, ''); + const res = await verifyOpenAIConnection( localStorage.token, { diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte index a5532ae2f2..d223db57ce 100644 --- a/src/lib/components/admin/Evaluations.svelte +++ b/src/lib/components/admin/Evaluations.svelte @@ -1,6 +1,8 @@ @@ -37,12 +59,13 @@ class="tabs flex flex-row overflow-x-auto gap-2.5 max-w-full lg:gap-1 lg:flex-col lg:flex-none lg:w-40 dark:text-gray-200 text-sm font-medium text-left scrollbar-none" >