Compare commits


No commits in common. "main" and "v0.6.12" have entirely different histories.

733 changed files with 42,501 additions and 123,761 deletions


@@ -7,15 +7,6 @@ OPENAI_API_KEY=''
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
# For production, you should only need one host as
# fastapi serves the svelte-kit built frontend and backend from the same host and port.
# To test with CORS locally, you can set something like
# CORS_ALLOW_ORIGIN='http://localhost:5173;http://localhost:8080'
CORS_ALLOW_ORIGIN='*'
# For production you should set this to match the proxy configuration (127.0.0.1)
FORWARDED_ALLOW_IPS='*'
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true

.gitattributes (vendored, 50 lines changed)

@@ -1,49 +1 @@
# TypeScript
*.sh text eol=lf
*.ts text eol=lf
*.tsx text eol=lf
# JavaScript
*.js text eol=lf
*.jsx text eol=lf
*.mjs text eol=lf
*.cjs text eol=lf
# Svelte
*.svelte text eol=lf
# HTML/CSS
*.html text eol=lf
*.css text eol=lf
*.scss text eol=lf
*.less text eol=lf
# Config files and JSON
*.json text eol=lf
*.jsonc text eol=lf
*.yml text eol=lf
*.yaml text eol=lf
*.toml text eol=lf
# Shell scripts
*.sh text eol=lf
# Markdown & docs
*.md text eol=lf
*.mdx text eol=lf
*.txt text eol=lf
# Git-related
.gitattributes text eol=lf
.gitignore text eol=lf
# Prettier and other dotfiles
.prettierrc text eol=lf
.prettierignore text eol=lf
.eslintrc text eol=lf
.eslintignore text eol=lf
.stylelintrc text eol=lf
.editorconfig text eol=lf
# Misc
*.env text eol=lf
*.lock text eol=lf


@@ -11,9 +11,7 @@ body:
## Important Notes
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) and [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project. Duplicates may be closed without notice. **Please search for existing issues AND discussions. No matter open or closed.**
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.
- Check for opened, **but also for (recently) CLOSED issues** as the issue you are trying to report **might already have been fixed on the dev branch!**
- **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.
@@ -21,19 +19,13 @@ body:
- **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
- **Scope**: If you want to report a SECURITY VULNERABILITY, then do so through our [GitHub security page](https://github.com/open-webui/open-webui/security).
- type: checkboxes
id: issue-check
attributes:
label: Check Existing Issues
description: Confirm that you've checked for existing reports before submitting a new one.
options:
- label: I have searched for any existing and/or related issues.
- label: I have searched the existing issues and discussions.
required: true
- label: I have searched for any existing and/or related discussions.
required: true
- label: I have also searched in the CLOSED issues AND CLOSED discussions and found no related items (your issue might already be addressed on the development branch!).
required: true
- label: I am using the latest version of Open WebUI.
required: true
@@ -55,7 +47,7 @@
id: open-webui-version
attributes:
label: Open WebUI Version
description: Specify the version (e.g., v0.6.26)
description: Specify the version (e.g., v0.3.11)
validations:
required: true
@@ -71,7 +63,7 @@
id: operating-system
attributes:
label: Operating System
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04, Debian 12)
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
validations:
required: true
@@ -134,7 +126,6 @@
description: |
Please provide a **very detailed, step-by-step guide** to reproduce the issue. Your instructions should be so clear and precise that anyone can follow them without guesswork. Include every relevant detail—settings, configuration options, exact commands used, values entered, and any prerequisites or environment variables.
**If full reproduction steps and all relevant settings are not provided, your issue may not be addressed.**
**If your steps to reproduction are incomplete, lacking detail or not reproducible, your issue can not be addressed.**
placeholder: |
Example (include every detail):
@@ -172,5 +163,5 @@
attributes:
value: |
## Note
**If the bug report is incomplete, does not follow instructions or is lacking details it may not be addressed.** Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
Thank you for contributing to Open WebUI!


@@ -8,21 +8,10 @@ body:
value: |
## Important Notes
### Before submitting
Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
Please check the **open AND closed** [Issues](https://github.com/open-webui/open-webui/issues) AND [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
It's likely we're already tracking it! If you're unsure, start a discussion post first.
#### Scope
If your feature request is likely to take more than a quick coding session to implement, test and verify, then open it in the **Ideas** section of the [Discussions](https://github.com/open-webui/open-webui/discussions) instead.
**We will close and force move your feature request to the Ideas section, if we believe your feature request is not trivial/quick to implement.**
This is to ensure the issues tab is used only for issues, quickly addressable feature requests and tracking tickets by the maintainers.
Other feature requests belong in the **Ideas** section of the [Discussions](https://github.com/open-webui/open-webui/discussions).
If your feature request might impact others in the community, definitely open a discussion instead and evaluate whether and how to implement it.
This will help us efficiently focus on improving the project.
### Collaborate respectfully
We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.
@@ -33,6 +22,7 @@ body:
We appreciate your time and ask that you **respect ours**.
### Contributing
If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
@@ -45,22 +35,14 @@ body:
label: Check Existing Issues
description: Please confirm that you've checked for existing similar requests
options:
- label: I have searched for all existing **open AND closed** issues and discussions for similar requests. I have found none that is comparable to my request.
- label: I have searched the existing issues and discussions.
required: true
- type: checkboxes
id: feature-scope
attributes:
label: Verify Feature Scope
description: Please confirm the feature's scope is within the described scope
options:
- label: I have read through and understood the scope definition for feature requests in the Issues section. I believe my feature request meets the definition and belongs in the Issues section instead of the Discussions.
required: true
- type: textarea
id: problem-description
attributes:
label: Problem Description
description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
placeholder: "Ex. I'm always frustrated when... / Not related to a problem"
placeholder: "Ex. I'm always frustrated when..."
validations:
required: true
- type: textarea


@@ -12,6 +12,12 @@ updates:
interval: monthly
target-branch: 'dev'
- package-ecosystem: npm
directory: '/'
schedule:
interval: monthly
target-branch: 'dev'
- package-ecosystem: 'github-actions'
directory: '/'
schedule:


@@ -1,20 +1,17 @@
# Pull Request Checklist
### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) to discuss your idea/fix with the community before creating a pull request, and describe your changes before submitting a pull request.
### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) and describe your changes before submitting a pull request.
This is to ensure large feature PRs are discussed with the community first, before starting work on it. If the community does not want this feature or it is not relevant for Open WebUI as a project, it can be identified in the discussion before working on the feature and submitting the PR.
**Before submitting, make sure you've checked the following:**
- [ ] **Target branch:** Verify that the pull request targets the `dev` branch. **Not targeting the `dev` branch will lead to immediate closure of the PR.**
- [ ] **Target branch:** Please verify that the pull request targets the `dev` branch.
- [ ] **Description:** Provide a concise description of the changes made in this pull request down below.
- [ ] **Description:** Provide a concise description of the changes made in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** If necessary, update relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs) like environment variables, the tutorials, or other documentation sources.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Perform manual tests to **verify the implemented fix/feature works as intended AND does not break any other functionality**. Take this as an opportunity to **make screenshots of the feature/fix and include it in the PR description**.
- [ ] **Testing:** Have you written and run sufficient tests to validate the changes?
- [ ] **Agentic AI Code:** Confirm this Pull Request is **not written by any AI Agent** or has at least **gone through additional human review AND manual testing**. If any AI Agent is the co-author of this PR, it may lead to immediate closure of the PR.
- [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
- [ ] **Title Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
- [ ] **Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
- **BREAKING CHANGE**: Significant changes that may affect compatibility
- **build**: Changes that affect the build system or external dependencies
- **ci**: Changes to our continuous integration processes or workflows
@@ -76,7 +73,4 @@ This is to ensure large feature PRs are discussed with the community first, befo
### Contributor License Agreement
By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](https://github.com/open-webui/open-webui/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.
By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.
> [!NOTE]
> Deleting the CLA section will lead to immediate closure of your PR and it will not be merged in.


@@ -11,7 +11,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Check for changes in package.json
run: |
@@ -36,7 +36,7 @@ jobs:
echo "::set-output name=content::$CHANGELOG_ESCAPED"
- name: Create GitHub release
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -61,7 +61,7 @@
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Trigger Docker build workflow
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
github.rest.actions.createWorkflowDispatch({


@@ -27,7 +27,7 @@ jobs:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
lfs: true
@@ -57,8 +57,7 @@
git lfs install
git lfs track "*.ttf"
git lfs track "*.jpg"
rm demo.png
rm demo.gif
rm banner.png
git add .
git commit -m "GitHub deploy: ${{ github.sha }}"
git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main


@@ -14,18 +14,16 @@ env:
jobs:
build-main-image:
runs-on: ${{ matrix.runner }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
platform:
- linux/amd64
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@@ -43,7 +41,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@@ -113,18 +111,16 @@ jobs:
retention-days: 1
build-cuda-image:
runs-on: ${{ matrix.runner }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
platform:
- linux/amd64
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@@ -141,11 +137,8 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@@ -217,125 +210,17 @@ jobs:
if-no-files-found: error
retention-days: 1
build-cuda126-image:
runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache
- name: Checkout repository
uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (cuda126 tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-cuda126,onlatest=true
- name: Extract metadata for Docker cache
id: cache-meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-cuda126-${{ matrix.platform }}-
latest=false
- name: Build Docker image (cuda126)
uses: docker/build-push-action@v5
id: build
with:
context: .
push: true
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: |
BUILD_HASH=${{ github.sha }}
USE_CUDA=true
USE_CUDA_VER=cu126
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-cuda126-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
build-ollama-image:
runs-on: ${{ matrix.runner }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
platform:
- linux/amd64
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@@ -353,7 +238,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@@ -425,108 +310,6 @@ jobs:
if-no-files-found: error
retention-days: 1
build-slim-image:
runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (slim tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=slim
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-slim,onlatest=true
- name: Extract metadata for Docker cache
id: cache-meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-slim-${{ matrix.platform }}-
latest=false
- name: Build Docker image (slim)
uses: docker/build-push-action@v5
id: build
with:
context: .
push: true
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: |
BUILD_HASH=${{ github.sha }}
USE_SLIM=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-slim-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge-main-images:
runs-on: ubuntu-latest
needs: [build-main-image]
@@ -541,7 +324,7 @@ jobs:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
uses: actions/download-artifact@v4
with:
pattern: digests-main-*
path: /tmp/digests
@@ -595,7 +378,7 @@ jobs:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
uses: actions/download-artifact@v4
with:
pattern: digests-cuda-*
path: /tmp/digests
@@ -637,62 +420,6 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-cuda126-images:
runs-on: ubuntu-latest
needs: [build-cuda126-image]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
with:
pattern: digests-cuda126-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (default latest tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-cuda126,onlatest=true
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-ollama-images:
runs-on: ubuntu-latest
needs: [build-ollama-image]
@@ -707,7 +434,7 @@ jobs:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
uses: actions/download-artifact@v4
with:
pattern: digests-ollama-*
path: /tmp/digests
@@ -748,59 +475,3 @@ jobs:
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-slim-images:
runs-on: ubuntu-latest
needs: [build-slim-image]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
with:
pattern: digests-slim-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (default slim tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=slim
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-slim,onlatest=true
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}


@@ -30,10 +30,10 @@ jobs:
- 3.12.x
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: '${{ matrix.python-version }}'


@@ -24,15 +24,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v5
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Install Dependencies
run: npm install --force
run: npm install
- name: Format Frontend
run: npm run format
@@ -51,15 +51,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v5
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Install Dependencies
run: npm ci --force
run: npm ci
- name: Run vitest
run: npm run test:frontend


@@ -16,15 +16,15 @@ jobs:
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install Git
run: sudo apt-get update && sudo apt-get install -y git
- uses: actions/setup-node@v5
- uses: actions/setup-node@v4
with:
node-version: 22
- uses: actions/setup-python@v6
- uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Build

.gitignore (vendored, 5 lines changed)

@@ -1,5 +1,3 @@
x.py
yarn.lock
.DS_Store
node_modules
/build
@@ -14,8 +12,7 @@ vite.config.ts.timestamp-*
__pycache__/
*.py[cod]
*$py.class
.nvmrc
CLAUDE.md
# C extensions
*.so


@@ -3,6 +3,8 @@ pnpm-lock.yaml
package-lock.json
yarn.lock
kubernetes/
# Copy of .gitignore
.DS_Store
node_modules


@@ -5,6 +5,5 @@
"printWidth": 100,
"plugins": ["prettier-plugin-svelte"],
"pluginSearchDirs": ["."],
"overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }],
"overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
"endOfLine": "lf"
}

File diff suppressed because it is too large.


@@ -3,8 +3,6 @@
# use build args in the docker build command with --build-arg="BUILDARG=true"
ARG USE_CUDA=false
ARG USE_OLLAMA=false
ARG USE_SLIM=false
ARG USE_PERMISSION_HARDENING=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
ARG USE_CUDA_VER=cu128
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
@@ -26,16 +24,13 @@ ARG GID=0
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
ARG BUILD_HASH
# Set Node.js options (heap limit Allocation failed - JavaScript heap out of memory)
# ENV NODE_OPTIONS="--max-old-space-size=4096"
WORKDIR /app
# to store git revision in build
RUN apk add --no-cache git
COPY package.json package-lock.json ./
RUN npm ci --force
RUN npm ci
COPY . .
ENV APP_BUILD_HASH=${BUILD_HASH}
@@ -48,23 +43,17 @@ FROM python:3.11-slim-bookworm AS base
ARG USE_CUDA
ARG USE_OLLAMA
ARG USE_CUDA_VER
ARG USE_SLIM
ARG USE_PERMISSION_HARDENING
ARG USE_EMBEDDING_MODEL
ARG USE_RERANKING_MODEL
ARG UID
ARG GID
# Python settings
ENV PYTHONUNBUFFERED=1
## Basis ##
ENV ENV=prod \
PORT=8080 \
# pass build args to the build
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
USE_CUDA_DOCKER=${USE_CUDA} \
USE_SLIM_DOCKER=${USE_SLIM} \
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
@@ -119,13 +108,29 @@ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry
# Make sure the user has access to the app and root directory
RUN chown -R $UID:$GID /app $HOME
# Install common system dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
git build-essential pandoc gcc netcat-openbsd curl jq \
python3-dev \
ffmpeg libsm6 libxext6 \
&& rm -rf /var/lib/apt/lists/*
RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# install helper tools
apt-get install -y --no-install-recommends curl jq && \
# install ollama
curl -fsSL https://ollama.com/install.sh | sh && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
else \
apt-get update && \
# Install pandoc, netcat and gcc
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
fi
# install python dependencies
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
@@ -141,22 +146,13 @@ RUN pip3 install --no-cache-dir uv && \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
if [ "$USE_SLIM" != "true" ]; then \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
fi; \
fi; \
chown -R $UID:$GID /app/backend/data/
mkdir -p /app/backend/data && chown -R $UID:$GID /app/backend/data/ && \
rm -rf /var/lib/apt/lists/*;
# Install Ollama if requested
RUN if [ "$USE_OLLAMA" = "true" ]; then \
date +%s > /tmp/ollama_build_hash && \
echo "Cache broken at timestamp: `cat /tmp/ollama_build_hash`" && \
curl -fsSL https://ollama.com/install.sh | sh && \
rm -rf /var/lib/apt/lists/*; \
fi
# copy embedding weight from build
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
@@ -174,17 +170,6 @@ EXPOSE 8080
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
# Minimal, atomic permission hardening for OpenShift (arbitrary UID):
# - Group 0 owns /app and /root
# - Directories are group-writable and have SGID so new files inherit GID 0
RUN if [ "$USE_PERMISSION_HARDENING" = "true" ]; then \
set -eux; \
chgrp -R 0 /app /root || true; \
chmod -R g+rwX /app /root || true; \
find /app -type d -exec chmod g+s {} + || true; \
find /root -type d -exec chmod g+s {} + || true; \
fi
USER $UID:$GID
ARG BUILD_HASH

INSTALLATION.md (new file, 35 lines)

@@ -0,0 +1,35 @@
### Installing Both Ollama and Open WebUI Using Kustomize
For cpu-only pod
```bash
kubectl apply -f ./kubernetes/manifest/base
```
For gpu-enabled pod
```bash
kubectl apply -k ./kubernetes/manifest
```
### Installing Both Ollama and Open WebUI Using Helm
Package Helm file first
```bash
helm package ./kubernetes/helm/
```
For cpu-only pod
```bash
helm install ollama-webui ./ollama-webui-*.tgz
```
For gpu-enabled pod
```bash
helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
```
Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization


@@ -1,4 +1,4 @@
Copyright (c) 2023- Open WebUI Inc. [Created by Timothy Jaeryang Baek]
Copyright (c) 2023-2025 Timothy Jaeryang Baek (Open WebUI)
All rights reserved.
Redistribution and use in source and binary forms, with or without


@@ -1,53 +0,0 @@
All code and materials created before commit `60d84a3aae9802339705826e9095e272e3c83623` are subject to the following copyright and license:
Copyright (c) 2023-2025 Timothy Jaeryang Baek
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
All code and materials created before commit `a76068d69cd59568b920dfab85dc573dbbb8f131` are subject to the following copyright and license:
MIT License
Copyright (c) 2023 Timothy Jaeryang Baek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -1,11 +0,0 @@
# Open WebUI Multi-License Notice
This repository contains code governed by multiple licenses based on the date and origin of contribution:
1. All code committed prior to commit a76068d69cd59568b920dfab85dc573dbbb8f131 is licensed under the MIT License (see LICENSE_HISTORY).
2. All code committed from commit a76068d69cd59568b920dfab85dc573dbbb8f131 up to and including commit 60d84a3aae9802339705826e9095e272e3c83623 is licensed under the BSD 3-Clause License (see LICENSE_HISTORY).
3. All code contributed or modified after commit 60d84a3aae9802339705826e9095e272e3c83623 is licensed under the Open WebUI License (see LICENSE).
For details on which commits are covered by which license, refer to LICENSE_HISTORY.


@@ -10,16 +10,12 @@
[![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
[![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
![Open WebUI Banner](./banner.png)
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
Passionate about open-source AI? [Join our team →](https://careers.openwebui.com/)
![Open WebUI Demo](./demo.gif)
![Open WebUI Demo](./demo.png)
> [!TIP]
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** **[Speak with Our Sales Team Today!](https://docs.openwebui.com/enterprise)**
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
>
> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
@@ -39,38 +35,24 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features using multiple Speech-to-Text providers (Local Whisper, OpenAI, Deepgram, Azure) and Text-to-Speech engines (Azure, ElevenLabs, OpenAI, Transformers, WebAPI), allowing for dynamic and interactive chat environments.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
- 💾 **Persistent Artifact Storage**: Built-in key-value storage API for artifacts, enabling features like journals, trackers, leaderboards, and collaborative tools with both personal and shared data scopes across sessions.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support using your choice of 9 vector databases and multiple content extraction engines (Tika, Docling, Document Intelligence, Mistral OCR, External loaders). Load documents directly into chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
- 🔍 **Web Search for RAG**: Perform web searches using 15+ providers including `SearXNG`, `Google PSE`, `Brave Search`, `Kagi`, `Mojeek`, `Tavily`, `Perplexity`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `SearchApi`, `SerpApi`, `Bing`, `Jina`, `Exa`, `Sougou`, `Azure AI Search`, and `Ollama Cloud`, injecting results directly into your chat experience.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
- 🎨 **Image Generation & Editing Integration**: Create and edit images using multiple engines including OpenAI's DALL-E, Gemini, ComfyUI (local), and AUTOMATIC1111 (local), with support for both generation and prompt-based editing workflows.
- 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
- 🗄️ **Flexible Database & Storage Options**: Choose from SQLite (with optional encryption), PostgreSQL, or configure cloud storage backends (S3, Google Cloud Storage, Azure Blob Storage) for scalable deployments.
- 🔍 **Advanced Vector Database Support**: Select from 9 vector database options including ChromaDB, PGVector, Qdrant, Milvus, Elasticsearch, OpenSearch, Pinecone, S3Vector, and Oracle 23ai for optimal RAG performance.
- 🔐 **Enterprise Authentication**: Full support for LDAP/Active Directory integration, SCIM 2.0 automated provisioning, and SSO via trusted headers alongside OAuth providers. Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
- ☁️ **Cloud-Native Integration**: Native support for Google Drive and OneDrive/SharePoint file picking, enabling seamless document import from enterprise cloud storage.
- 📊 **Production Observability**: Built-in OpenTelemetry support for traces, metrics, and logs, enabling comprehensive monitoring with your existing observability stack.
- ⚖️ **Horizontal Scalability**: Redis-backed session management and WebSocket support for multi-worker and multi-node deployments behind load balancers.
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
@ -79,6 +61,33 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview! Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
## Sponsors 🙌
#### Emerald
<table>
<tr>
<td>
<a href="https://n8n.io/" target="_blank">
<img src="https://docs.openwebui.com/sponsors/logos/n8n.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
</a>
</td>
<td>
n8n • Does your interface have a backend yet?<br>Try <a href="https://n8n.io/">n8n</a>
</td>
</tr>
<tr>
<td>
<a href="https://warp.dev/open-webui" target="_blank">
<img src="https://docs.openwebui.com/sponsors/logos/warp.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
</a>
</td>
<td>
Warp • The intelligent terminal for developers
</td>
</tr>
</table>
--- ---
We are incredibly grateful for the generous support of our sponsors. Their contributions help us to maintain and improve our project, ensuring we can continue to deliver quality work to our community. Thank you! We are incredibly grateful for the generous support of our sponsors. Their contributions help us to maintain and improve our project, ensuring we can continue to deliver quality work to our community. Thank you!
@ -172,8 +181,6 @@ After installation, you can access Open WebUI at [http://localhost:3000](http://
We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance. We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
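For reference, a minimal sketch of the non-Docker native route (assuming Python 3.11; see the documentation above for the authoritative, up-to-date steps):

```bash
# Install and start Open WebUI directly with pip (requires Python 3.11)
pip install open-webui
open-webui serve  # serves the app on http://localhost:8080 by default
```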
Look at the [Local Development Guide](https://docs.openwebui.com/getting-started/advanced-topics/development) for instructions on setting up a local development environment.
### Troubleshooting ### Troubleshooting
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s). Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
@ -190,6 +197,14 @@ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=
### Keeping Your Docker Installation Up-to-Date ### Keeping Your Docker Installation Up-to-Date
In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
```bash
docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
```
In the last part of the command, replace `open-webui` with your container name if it is different.
Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating). Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
### Using the Dev Branch 🌙 ### Using the Dev Branch 🌙
@ -217,7 +232,7 @@ Discover upcoming features on our roadmap in the [Open WebUI Documentation](http
## License 📜 ## License 📜
This project contains code under multiple licenses. The current codebase includes components licensed under the Open WebUI License with an additional requirement to preserve the "Open WebUI" branding, as well as prior contributions under their respective original licenses. For a detailed record of license changes and the applicable terms for each section of the code, please refer to [LICENSE_HISTORY](./LICENSE_HISTORY). For complete and updated licensing details, please see the [LICENSE](./LICENSE) and [LICENSE_HISTORY](./LICENSE_HISTORY) files. This project is licensed under the [Open WebUI License](LICENSE), a revised BSD-3-Clause license. You receive all the same rights as the classic BSD-3 license: you can use, modify, and distribute the software, including in proprietary and commercial products, with minimal restrictions. The only additional requirement is to preserve the "Open WebUI" branding, as detailed in the LICENSE file. For full terms, see the [LICENSE](LICENSE) document. 📄
## Support 💬 ## Support 💬

View file

@ -1,3 +1,2 @@
export CORS_ALLOW_ORIGIN="http://localhost:5173;http://localhost:8080"
PORT="${PORT:-8080}" PORT="${PORT:-8080}"
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload

View file

@ -10,7 +10,7 @@ script_location = migrations
# sys.path path, will be prepended to sys.path if present. # sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. # defaults to the current working directory.
prepend_sys_path = .. prepend_sys_path = .
# timezone to use when rendering the date within the migration file # timezone to use when rendering the date within the migration file
# as well as the filename. # as well as the filename.

File diff suppressed because it is too large

View file

@ -38,14 +38,13 @@ class ERROR_MESSAGES(str, Enum):
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long."
INVALID_TOKEN = ( INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again." "Your session has expired or the token is invalid. Please sign in again."
) )
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)." INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INCORRECT_PASSWORD = ( INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again." "The password provided is incorrect. Please check for typos and try again."
) )
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance." INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
@ -105,10 +104,6 @@ class ERROR_MESSAGES(str, Enum):
) )
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding." FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
INVALID_PASSWORD = lambda err="": (
err if err else "The password does not meet the required validation criteria."
)
class TASKS(str, Enum): class TASKS(str, Enum):
def __str__(self) -> str: def __str__(self) -> str:
@ -116,7 +111,6 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}" DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation" TITLE_GENERATION = "title_generation"
FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation" TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation" EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation" QUERY_GENERATION = "query_generation"

View file

@ -5,11 +5,7 @@ import os
import pkgutil import pkgutil
import sys import sys
import shutil import shutil
from uuid import uuid4
from pathlib import Path from pathlib import Path
from cryptography.hazmat.primitives import serialization
import re
import markdown import markdown
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -19,17 +15,14 @@ from open_webui.constants import ERROR_MESSAGES
# Load .env file # Load .env file
#################################### ####################################
# Use .resolve() to get the canonical path, removing any '..' or '.' components OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
ENV_FILE_PATH = Path(__file__).resolve() print(OPEN_WEBUI_DIR)
# OPEN_WEBUI_DIR should be the directory where env.py resides (open_webui/) BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
OPEN_WEBUI_DIR = ENV_FILE_PATH.parent BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
# BACKEND_DIR is the parent of OPEN_WEBUI_DIR (backend/) print(BACKEND_DIR)
BACKEND_DIR = OPEN_WEBUI_DIR.parent print(BASE_DIR)
# BASE_DIR is the parent of BACKEND_DIR (open-webui-dev/)
BASE_DIR = BACKEND_DIR.parent
try: try:
from dotenv import find_dotenv, load_dotenv from dotenv import find_dotenv, load_dotenv
@ -85,7 +78,32 @@ if "cuda_error" in locals():
log.exception(cuda_error) log.exception(cuda_error)
del cuda_error del cuda_error
SRC_LOG_LEVELS = {} # Legacy variable, do not remove log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
"SOCKET",
"OAUTH",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI": if WEBUI_NAME != "Open WebUI":
@ -114,10 +132,6 @@ else:
VERSION = PACKAGE_DATA["version"] VERSION = PACKAGE_DATA["version"]
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "")
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
# Function to parse each section # Function to parse each section
def parse_section(section): def parse_section(section):
items = [] items = []
@ -183,7 +197,6 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
#################################### ####################################
# ENABLE_FORWARD_USER_INFO_HEADERS # ENABLE_FORWARD_USER_INFO_HEADERS
#################################### ####################################
@ -192,11 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
) )
# Experimental feature, may be removed in future
ENABLE_STAR_SESSIONS_MIDDLEWARE = (
os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true"
)
#################################### ####################################
# WEBUI_BUILD_HASH # WEBUI_BUILD_HASH
#################################### ####################################
@ -256,43 +264,21 @@ else:
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
DATABASE_TYPE = os.environ.get("DATABASE_TYPE")
DATABASE_USER = os.environ.get("DATABASE_USER")
DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD")
DATABASE_CRED = ""
if DATABASE_USER:
DATABASE_CRED += f"{DATABASE_USER}"
if DATABASE_PASSWORD:
DATABASE_CRED += f":{DATABASE_PASSWORD}"
DB_VARS = {
"db_type": DATABASE_TYPE,
"db_cred": DATABASE_CRED,
"db_host": os.environ.get("DATABASE_HOST"),
"db_port": os.environ.get("DATABASE_PORT"),
"db_name": os.environ.get("DATABASE_NAME"),
}
if all(DB_VARS.values()):
DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"):
# Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set
DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db"
# Replace the postgres:// with postgresql:// # Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL: if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None) DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None) DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
if DATABASE_POOL_SIZE != None: if DATABASE_POOL_SIZE == "":
DATABASE_POOL_SIZE = 0
else:
try: try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE) DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception: except Exception:
DATABASE_POOL_SIZE = None DATABASE_POOL_SIZE = 0
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
@ -324,26 +310,6 @@ else:
except Exception: except Exception:
DATABASE_POOL_RECYCLE = 3600 DATABASE_POOL_RECYCLE = 3600
DATABASE_ENABLE_SQLITE_WAL = (
os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true"
)
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get(
"DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL", None
)
if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None:
try:
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float(
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
)
except Exception:
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0
# Enable public visibility of active user count (when disabled, only admins can see it)
ENABLE_PUBLIC_ACTIVE_USERS_COUNT = (
os.environ.get("ENABLE_PUBLIC_ACTIVE_USERS_COUNT", "True").lower() == "true"
)
RESET_CONFIG_ON_START = ( RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
) )
@ -352,36 +318,14 @@ ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true" os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
) )
ENABLE_QUERIES_CACHE = os.environ.get("ENABLE_QUERIES_CACHE", "False").lower() == "true"
#################################### ####################################
# REDIS # REDIS
#################################### ####################################
REDIS_URL = os.environ.get("REDIS_URL", "") REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
# Maximum number of retries for Redis operations when using Sentinel fail-over
REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2")
try:
REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT)
if REDIS_SENTINEL_MAX_RETRY_COUNT < 1:
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
except ValueError:
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get("REDIS_SOCKET_CONNECT_TIMEOUT", "")
try:
REDIS_SOCKET_CONNECT_TIMEOUT = float(REDIS_SOCKET_CONNECT_TIMEOUT)
except ValueError:
REDIS_SOCKET_CONNECT_TIMEOUT = None
#################################### ####################################
# UVICORN WORKERS # UVICORN WORKERS
#################################### ####################################
@ -401,14 +345,6 @@ except ValueError:
#################################### ####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
ENABLE_INITIAL_ADMIN_SIGNUP = (
os.environ.get("ENABLE_INITIAL_ADMIN_SIGNUP", "False").lower() == "true"
)
ENABLE_SIGNUP_PASSWORD_CONFIRMATION = (
os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true"
)
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
) )
@ -418,23 +354,6 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
) )
ENABLE_PASSWORD_VALIDATION = (
os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true"
)
PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get(
"PASSWORD_VALIDATION_REGEX_PATTERN",
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$",
)
try:
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN)
except Exception as e:
log.error(f"Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}")
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$"
)
BYPASS_MODEL_ACCESS_CONTROL = ( BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
) )
@ -475,197 +394,18 @@ WEBUI_AUTH_COOKIE_SECURE = (
if WEBUI_AUTH and WEBUI_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
ENABLE_COMPRESSION_MIDDLEWARE = (
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
)
####################################
# OAUTH Configuration
####################################
ENABLE_OAUTH_EMAIL_FALLBACK = (
os.environ.get("ENABLE_OAUTH_EMAIL_FALLBACK", "False").lower() == "true"
)
ENABLE_OAUTH_ID_TOKEN_COOKIE = (
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
)
OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get(
"OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY
)
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
)
####################################
# SCIM Configuration
####################################
ENABLE_SCIM = (
os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower()
== "true"
)
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
####################################
# LICENSE_KEY
####################################
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
LICENSE_BLOB = None
LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data")
if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH):
with open(LICENSE_BLOB_PATH, "rb") as f:
LICENSE_BLOB = f.read()
LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "")
pk = None
if LICENSE_PUBLIC_KEY:
pk = serialization.load_pem_public_key(
f"""
-----BEGIN PUBLIC KEY-----
{LICENSE_PUBLIC_KEY}
-----END PUBLIC KEY-----
""".encode(
"utf-8"
)
)
####################################
# MODELS
####################################
ENABLE_CUSTOM_MODEL_FALLBACK = (
os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
)
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
if MODELS_CACHE_TTL == "":
MODELS_CACHE_TTL = None
else:
try:
MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
except Exception:
MODELS_CACHE_TTL = 1
####################################
# CHAT
####################################
ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = (
os.environ.get("ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION", "False").lower()
== "true"
)
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
"CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
)
if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "":
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
else:
try:
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int(
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE
)
except Exception:
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get(
"CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30"
)
if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "":
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
else:
try:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = int(CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES)
except Exception:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get(
"CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", ""
)
if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "":
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
else:
try:
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
)
except Exception:
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
####################################
# WEBSOCKET SUPPORT
####################################
ENABLE_WEBSOCKET_SUPPORT = ( ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
) )
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "")
if WEBSOCKET_REDIS_OPTIONS == "":
if REDIS_SOCKET_CONNECT_TIMEOUT:
WEBSOCKET_REDIS_OPTIONS = {
"socket_connect_timeout": REDIS_SOCKET_CONNECT_TIMEOUT
}
else:
log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None")
WEBSOCKET_REDIS_OPTIONS = None
else:
try:
WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS)
except Exception:
log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None")
WEBSOCKET_REDIS_OPTIONS = None
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_CLUSTER = ( WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
)
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
try:
WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout)
except ValueError:
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
WEBSOCKET_SERVER_LOGGING = (
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
)
WEBSOCKET_SERVER_ENGINEIO_LOGGING = (
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
)
WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20")
try:
WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT)
except ValueError:
WEBSOCKET_SERVER_PING_TIMEOUT = 20
WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25")
try:
WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL)
except ValueError:
WEBSOCKET_SERVER_PING_INTERVAL = 25
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@ -764,32 +504,19 @@ else:
# OFFLINE_MODE # OFFLINE_MODE
#################################### ####################################
ENABLE_VERSION_UPDATE_CHECK = (
os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true"
)
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE: if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1" os.environ["HF_HUB_OFFLINE"] = "1"
ENABLE_VERSION_UPDATE_CHECK = False
#################################### ####################################
# AUDIT LOGGING # AUDIT LOGGING
#################################### ####################################
# Where to store log file # Where to store log file
# Defaults to the DATA_DIR/audit.log. To set AUDIT_LOGS_FILE_PATH you need to AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# provide the whole path, like: /app/audit.log
AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log")
# Maximum size of a file before rotating into a new log file # Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# Comma separated list of logger names to use for audit logging
# Default is "uvicorn.access" which is the access log for Uvicorn
# You can add more logger names to this list if you want to capture more logs
AUDIT_UVICORN_LOGGER_NAMES = os.getenv(
"AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access"
).split(",")
# METADATA | REQUEST | REQUEST_RESPONSE # METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper() AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try: try:
@ -810,34 +537,9 @@ AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
#################################### ####################################
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true" ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true"
ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get( OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317" "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
) )
OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_EXPORTER_OTLP_INSECURE = (
os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
)
OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui") OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get( OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", "" "OTEL_RESOURCE_ATTRIBUTES", ""
@ -845,33 +547,6 @@ OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
OTEL_TRACES_SAMPLER = os.environ.get( OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on" "OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower() ).lower()
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)
OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)
OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_OTLP_SPAN_EXPORTER", "grpc"
).lower() # grpc or http
OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http
OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http
#################################### ####################################
# TOOLS/FUNCTIONS PIP OPTIONS # TOOLS/FUNCTIONS PIP OPTIONS

View file

@ -19,14 +19,12 @@ from fastapi import (
from starlette.responses import Response, StreamingResponse from starlette.responses import Response, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.socket.main import ( from open_webui.socket.main import (
get_event_call, get_event_call,
get_event_emitter, get_event_emitter,
) )
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
@ -37,7 +35,7 @@ from open_webui.utils.plugin import (
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.env import GLOBAL_LOG_LEVEL from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.utils.misc import ( from open_webui.utils.misc import (
add_or_update_system_message, add_or_update_system_message,
@ -48,32 +46,21 @@ from open_webui.utils.misc import (
) )
from open_webui.utils.payload import ( from open_webui.utils.payload import (
apply_model_params_to_body_openai, apply_model_params_to_body_openai,
apply_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str): def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id) function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Valves = function_module.Valves
valves = Functions.get_function_valves_by_id(pipe_id) valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
else:
function_module.valves = Valves()
return function_module return function_module
@ -82,75 +69,65 @@ async def get_function_models(request):
pipe_models = [] pipe_models = []
for pipe in pipes: for pipe in pipes:
try: function_module = get_function_module_by_id(request, pipe.id)
function_module = get_function_module_by_id(request, pipe.id)
has_user_valves = False # Check if function is a manifold
if hasattr(function_module, "UserValves"): if hasattr(function_module, "pipes"):
has_user_valves = True sub_pipes = []
# Check if function is a manifold # Handle pipes being a list, sync function, or async function
if hasattr(function_module, "pipes"): try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = [] sub_pipes = []
# Handle pipes being a list, sync function, or async function log.debug(
try: f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
if callable(function_module.pipes): )
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug( for p in sub_pipes:
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" sub_pipe_id = f'{pipe.id}.{p["id"]}'
) sub_pipe_name = p["name"]
for p in sub_pipes: if hasattr(function_module, "name"):
sub_pipe_id = f'{pipe.id}.{p["id"]}' sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
sub_pipe_name = p["name"]
if hasattr(function_module, "name"): pipe_flag = {"type": pipe.type}
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
"has_user_valves": has_user_valves,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append( pipe_models.append(
{ {
"id": pipe.id, "id": sub_pipe_id,
"name": pipe.name, "name": sub_pipe_name,
"object": "model", "object": "model",
"created": pipe.created_at, "created": pipe.created_at,
"owned_by": "openai", "owned_by": "openai",
"pipe": pipe_flag, "pipe": pipe_flag,
"has_user_valves": has_user_valves,
} }
) )
except Exception as e: else:
log.exception(e) pipe_flag = {"type": "pipe"}
continue
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models return pipe_models
@ -241,16 +218,6 @@ async def generate_function_chat_completion(
__task__ = metadata.get("task", None) __task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None) __task_body__ = metadata.get("task_body", None)
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
extra_params = { extra_params = {
"__event_emitter__": __event_emitter__, "__event_emitter__": __event_emitter__,
"__event_call__": __event_call__, "__event_call__": __event_call__,
@ -260,12 +227,16 @@ async def generate_function_chat_completion(
"__task__": __task__, "__task__": __task__,
"__task_body__": __task_body__, "__task_body__": __task_body__,
"__files__": files, "__files__": files,
"__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata, "__metadata__": metadata,
"__oauth_token__": oauth_token,
"__request__": request, "__request__": request,
} }
extra_params["__tools__"] = await get_tools( extra_params["__tools__"] = get_tools(
request, request,
tool_ids, tool_ids,
user, user,
@ -282,11 +253,8 @@ async def generate_function_chat_completion(
form_data["model"] = model_info.base_model_id form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
if params: form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
system = params.pop("system", None)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id) function_module = get_function_module_by_id(request, pipe_id)

View file

@ -1,4 +1,3 @@
import os
import json import json
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
@ -9,14 +8,14 @@ from open_webui.env import (
OPEN_WEBUI_DIR, OPEN_WEBUI_DIR,
DATABASE_URL, DATABASE_URL,
DATABASE_SCHEMA, DATABASE_SCHEMA,
SRC_LOG_LEVELS,
DATABASE_POOL_MAX_OVERFLOW, DATABASE_POOL_MAX_OVERFLOW,
DATABASE_POOL_RECYCLE, DATABASE_POOL_RECYCLE,
DATABASE_POOL_SIZE, DATABASE_POOL_SIZE,
DATABASE_POOL_TIMEOUT, DATABASE_POOL_TIMEOUT,
DATABASE_ENABLE_SQLITE_WAL,
) )
from peewee_migrate import Router from peewee_migrate import Router
from sqlalchemy import Dialect, create_engine, MetaData, event, types from sqlalchemy import Dialect, create_engine, MetaData, types
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.pool import QueuePool, NullPool
@ -24,6 +23,7 @@ from sqlalchemy.sql.type_api import _T
from typing_extensions import Self from typing_extensions import Self
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(types.TypeDecorator): class JSONField(types.TypeDecorator):
@ -62,9 +62,6 @@ def handle_peewee_migration(DATABASE_URL):
except Exception as e: except Exception as e:
log.error(f"Failed to initialize the database connection: {e}") log.error(f"Failed to initialize the database connection: {e}")
log.warning(
"Hint: If your database password contains special characters, you may need to URL-encode it."
)
raise raise
finally: finally:
# Properly closing the database connection # Properly closing the database connection
@ -79,66 +76,25 @@ handle_peewee_migration(DATABASE_URL)
SQLALCHEMY_DATABASE_URL = DATABASE_URL SQLALCHEMY_DATABASE_URL = DATABASE_URL
if "sqlite" in SQLALCHEMY_DATABASE_URL:
# Handle SQLCipher URLs
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
# Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection():
import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{database_password}'")
return conn
engine = create_engine(
"sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection,
echo=False,
)
log.info("Connected to encrypted SQLite database using SQLCipher")
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine( engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
) )
def on_connect(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
if DATABASE_ENABLE_SQLITE_WAL:
cursor.execute("PRAGMA journal_mode=WAL")
else:
cursor.execute("PRAGMA journal_mode=DELETE")
cursor.close()
event.listen(engine, "connect", on_connect)
else: else:
if isinstance(DATABASE_POOL_SIZE, int): if DATABASE_POOL_SIZE > 0:
if DATABASE_POOL_SIZE > 0: engine = create_engine(
engine = create_engine( SQLALCHEMY_DATABASE_URL,
SQLALCHEMY_DATABASE_URL, pool_size=DATABASE_POOL_SIZE,
pool_size=DATABASE_POOL_SIZE, max_overflow=DATABASE_POOL_MAX_OVERFLOW,
max_overflow=DATABASE_POOL_MAX_OVERFLOW, pool_timeout=DATABASE_POOL_TIMEOUT,
pool_timeout=DATABASE_POOL_TIMEOUT, pool_recycle=DATABASE_POOL_RECYCLE,
pool_recycle=DATABASE_POOL_RECYCLE, pool_pre_ping=True,
pool_pre_ping=True, poolclass=QueuePool,
poolclass=QueuePool, )
)
else:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker( SessionLocal = sessionmaker(

View file

@ -1,7 +1,7 @@
import logging import logging
import os
from contextvars import ContextVar from contextvars import ContextVar
from open_webui.env import SRC_LOG_LEVELS
from peewee import * from peewee import *
from peewee import InterfaceError as PeeWeeInterfaceError from peewee import InterfaceError as PeeWeeInterfaceError
from peewee import PostgresqlDatabase from peewee import PostgresqlDatabase
@ -9,6 +9,7 @@ from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin from playhouse.shortcuts import ReconnectMixin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy()) db_state = ContextVar("db_state", default=db_state_default.copy())
@ -42,45 +43,24 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url): def register_connection(db_url):
# Check if using SQLCipher protocol db = connect(db_url, unquote_user=True, unquote_password=True)
if db_url.startswith("sqlite+sqlcipher://"): if isinstance(db, PostgresqlDatabase):
database_password = os.environ.get("DATABASE_PASSWORD") # Enable autoconnect for SQLite databases, managed by Peewee
if not database_password or database_password.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
from playhouse.sqlcipher_ext import SqlCipherDatabase
# Parse the database path from SQLCipher URL
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
db_path = db_url.replace("sqlite+sqlcipher://", "")
# Use Peewee's native SqlCipherDatabase with encryption
db = SqlCipherDatabase(db_path, passphrase=database_password)
db.autoconnect = True db.autoconnect = True
db.reuse_if_open = True db.reuse_if_open = True
log.info("Connected to encrypted SQLite database using SQLCipher") log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)
db.connect(reuse_if_open=True)
elif isinstance(db, SqliteDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
else: else:
# Standard database connection (existing logic) raise ValueError("Unsupported database connection")
db = connect(db_url, unquote_user=True, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)
db.connect(reuse_if_open=True)
elif isinstance(db, SqliteDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
else:
raise ValueError("Unsupported database connection")
return db return db

File diff suppressed because it is too large

View file

@ -2,8 +2,8 @@ from logging.config import fileConfig
from alembic import context from alembic import context
from open_webui.models.auths import Auth from open_webui.models.auths import Auth
from open_webui.env import DATABASE_URL, DATABASE_PASSWORD from open_webui.env import DATABASE_URL
from sqlalchemy import engine_from_config, pool, create_engine from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
@ -62,38 +62,11 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
# Handle SQLCipher URLs connectable = engine_from_config(
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"): config.get_section(config.config_ini_section, {}),
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": prefix="sqlalchemy.",
raise ValueError( poolclass=pool.NullPool,
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" )
)
# Extract database path from SQLCipher URL
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
if db_path.startswith("/"):
db_path = db_path[1:] # Remove leading slash for relative paths
# Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection():
import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
return conn
connectable = create_engine(
"sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection,
echo=False,
)
else:
# Standard database connection (existing logic)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata) context.configure(connection=connection, target_metadata=target_metadata)

View file

@ -1,46 +0,0 @@
"""Add indexes
Revision ID: 018012973d35
Revises: d31026856c01
Create Date: 2025-08-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "018012973d35"
down_revision = "d31026856c01"
branch_labels = None
depends_on = None
def upgrade():
# Chat table indexes
op.create_index("folder_id_idx", "chat", ["folder_id"])
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"])
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"])
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"])
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"])
# Tag table index
op.create_index("user_id_idx", "tag", ["user_id"])
# Function table index
op.create_index("is_global_idx", "function", ["is_global"])
def downgrade():
# Chat table indexes
op.drop_index("folder_id_idx", table_name="chat")
op.drop_index("user_id_pinned_idx", table_name="chat")
op.drop_index("user_id_archived_idx", table_name="chat")
op.drop_index("updated_at_user_id_idx", table_name="chat")
op.drop_index("folder_id_user_id_idx", table_name="chat")
# Tag table index
op.drop_index("user_id_idx", table_name="tag")
# Function table index
op.drop_index("is_global_idx", table_name="function")

View file

@ -1,103 +0,0 @@
"""Update messages and channel member table
Revision ID: 2f1211949ecc
Revises: 37f288994c47
Create Date: 2025-11-27 03:07:56.200231
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "2f1211949ecc"
down_revision: Union[str, None] = "37f288994c47"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# New columns to be added to channel_member table
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
op.add_column(
"channel_member",
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
default=True,
server_default=sa.sql.expression.true(),
),
)
op.add_column(
"channel_member",
sa.Column(
"is_channel_muted",
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column(
"channel_member",
sa.Column(
"is_channel_pinned",
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
op.add_column(
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
)
op.add_column(
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
)
op.add_column(
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
)
op.add_column(
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
)
# New columns to be added to message table
op.add_column(
"message",
sa.Column(
"is_pinned",
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("channel_member", "updated_at")
op.drop_column("channel_member", "last_read_at")
op.drop_column("channel_member", "meta")
op.drop_column("channel_member", "data")
op.drop_column("channel_member", "is_channel_pinned")
op.drop_column("channel_member", "is_channel_muted")
op.drop_column("message", "pinned_by")
op.drop_column("message", "pinned_at")
op.drop_column("message", "is_pinned")

View file

@ -1,146 +0,0 @@
"""add_group_member_table
Revision ID: 37f288994c47
Revises: a5c220713937
Create Date: 2025-11-17 03:45:25.123939
"""
import uuid
import time
import json
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "37f288994c47"
down_revision: Union[str, None] = "a5c220713937"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. Create new table
op.create_table(
"group_member",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column(
"group_id",
sa.Text(),
sa.ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"user_id",
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
)
connection = op.get_bind()
# 2. Read existing group with user_ids JSON column
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(group_table.c.id, group_table.c.user_ids)
).fetchall()
print(results)
# 3. Insert members into group_member table
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
now = int(time.time())
for group_id, user_ids in results:
if not user_ids:
continue
if isinstance(user_ids, str):
try:
user_ids = json.loads(user_ids)
except Exception:
continue # skip invalid JSON
if not isinstance(user_ids, list):
continue
rows = [
{
"id": str(uuid.uuid4()),
"group_id": group_id,
"user_id": uid,
"created_at": now,
"updated_at": now,
}
for uid in user_ids
]
if rows:
connection.execute(gm_table.insert(), rows)
# 4. Optionally drop the old column
with op.batch_alter_table("group") as batch:
batch.drop_column("user_ids")
def downgrade():
# Reverse: restore user_ids column
with op.batch_alter_table("group") as batch:
batch.add_column(sa.Column("user_ids", sa.JSON()))
connection = op.get_bind()
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()),
)
# Build JSON arrays again
results = connection.execute(sa.select(group_table.c.id)).fetchall()
for (group_id,) in results:
members = connection.execute(
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
).fetchall()
member_ids = [m[0] for m in members]
connection.execute(
group_table.update()
.where(group_table.c.id == group_id)
.values(user_ids=member_ids)
)
# Drop the new table
op.drop_table("group_member")

View file

@ -1,80 +0,0 @@
"""Add oauth_session table
Revision ID: 38d63c18f30f
Revises: 3af16a1c9fb6
Create Date: 2025-09-08 14:19:59.583921
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "38d63c18f30f"
down_revision: Union[str, None] = "3af16a1c9fb6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
inspector = sa.inspect(op.get_bind())
columns = inspector.get_columns("user")
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
id_column = next((col for col in columns if col["name"] == "id"), None)
if id_column and not id_column.get("unique", False):
unique_constraints = inspector.get_unique_constraints("user")
unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
with op.batch_alter_table("user") as batch_op:
# If primary key is wrong, drop it
if pk_columns and pk_columns != ["id"]:
batch_op.drop_constraint(
inspector.get_pk_constraint("user")["name"], type_="primary"
)
# Add unique constraint if missing
if ("id",) not in unique_columns:
batch_op.create_unique_constraint("uq_user_id", ["id"])
# Re-create correct primary key
batch_op.create_primary_key("pk_user_id", ["id"])
# Create oauth_session table
op.create_table(
"oauth_session",
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
sa.Column(
"user_id",
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("provider", sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("expires_at", sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
# Create indexes for better performance
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
op.create_index(
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
)
def downgrade() -> None:
# Drop indexes first
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
# Drop the table
op.drop_table("oauth_session")

View file

@ -1,32 +0,0 @@
"""update user table
Revision ID: 3af16a1c9fb6
Revises: 018012973d35
Create Date: 2025-08-21 02:07:18.078283
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "3af16a1c9fb6"
down_revision: Union[str, None] = "018012973d35"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "username")
op.drop_column("user", "bio")
op.drop_column("user", "gender")
op.drop_column("user", "date_of_birth")

View file

@ -1,169 +0,0 @@
"""Add knowledge_file table
Revision ID: 3e0e00844bb0
Revises: 90ef40d4714e
Create Date: 2025-12-02 06:54:19.401334
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import open_webui.internal.db
import time
import json
import uuid
# revision identifiers, used by Alembic.
revision: str = "3e0e00844bb0"
down_revision: Union[str, None] = "90ef40d4714e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"knowledge_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"knowledge_id",
sa.Text(),
sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"file_id",
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
sa.Index("ix_knowledge_file_file_id", "file_id"),
sa.Index("ix_knowledge_file_user_id", "user_id"),
# unique constraints
sa.UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
), # prevent duplicate entries
)
connection = op.get_bind()
# 2. Read existing knowledge rows and their 'data' JSON column
knowledge_table = sa.Table(
"knowledge",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
)
).fetchall()
# 3. Insert one row per referenced file into the knowledge_file table
kf_table = sa.Table(
"knowledge_file",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("knowledge_id", sa.Text()),
sa.Column("file_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
file_table = sa.Table(
"file",
sa.MetaData(),
sa.Column("id", sa.Text()),
)
now = int(time.time())
for knowledge_id, user_id, data in results:
if not data:
continue
if isinstance(data, str):
try:
data = json.loads(data)
except Exception:
continue # skip invalid JSON
if not isinstance(data, dict):
continue
file_ids = data.get("file_ids", [])
for file_id in file_ids:
file_exists = connection.execute(
sa.select(file_table.c.id).where(file_table.c.id == file_id)
).fetchone()
if not file_exists:
continue # skip non-existing files
row = {
"id": str(uuid.uuid4()),
"user_id": user_id,
"knowledge_id": knowledge_id,
"file_id": file_id,
"created_at": now,
"updated_at": now,
}
connection.execute(kf_table.insert().values(**row))
with op.batch_alter_table("knowledge") as batch:
batch.drop_column("data")
def downgrade() -> None:
# 1. Add back the old data column
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
connection = op.get_bind()
# 2. Read knowledge_file entries and reconstruct data JSON
knowledge_table = sa.Table(
"knowledge",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("data", sa.JSON()),
)
kf_table = sa.Table(
"knowledge_file",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("knowledge_id", sa.Text()),
sa.Column("file_id", sa.Text()),
)
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
for (knowledge_id,) in results:
file_ids = connection.execute(
sa.select(kf_table.c.file_id).where(kf_table.c.knowledge_id == knowledge_id)
).fetchall()
file_ids_list = [fid for (fid,) in file_ids]
data_json = {"file_ids": file_ids_list}
connection.execute(
knowledge_table.update()
.where(knowledge_table.c.id == knowledge_id)
.values(data=data_json)
)
# 3. Drop the knowledge_file table
op.drop_table("knowledge_file")

View file

@@ -1,54 +0,0 @@
"""Add channel file table
Revision ID: 6283dc0e4d8d
Revises: 3e0e00844bb0
Create Date: 2025-12-10 15:11:39.424601
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "6283dc0e4d8d"
down_revision: Union[str, None] = "3e0e00844bb0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"channel_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"channel_id",
sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"file_id",
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_channel_file_channel_id", "channel_id"),
sa.Index("ix_channel_file_file_id", "file_id"),
sa.Index("ix_channel_file_user_id", "user_id"),
# unique constraints
sa.UniqueConstraint(
"channel_id", "file_id", name="uq_channel_file_channel_file"
), # prevent duplicate entries
)
def downgrade() -> None:
op.drop_table("channel_file")

View file

@@ -1,49 +0,0 @@
"""Update channel file and knowledge table
Revision ID: 81cc2ce44d79
Revises: 6283dc0e4d8d
Create Date: 2025-12-10 16:07:58.001282
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "81cc2ce44d79"
down_revision: Union[str, None] = "6283dc0e4d8d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add message_id column to channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"message_id",
sa.Text(),
sa.ForeignKey(
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
),
nullable=True,
)
)
# Add data column to knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op:
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True))
def downgrade() -> None:
# Remove message_id column from channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op:
batch_op.drop_column("message_id")
# Remove data column from knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op:
batch_op.drop_column("data")

View file

@@ -1,81 +0,0 @@
"""Update channel and channel members table
Revision ID: 90ef40d4714e
Revises: b10670c03dd5
Create Date: 2025-11-30 06:33:38.790341
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "90ef40d4714e"
down_revision: Union[str, None] = "b10670c03dd5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Update 'channel' table
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
# Update 'channel_member' table
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
op.add_column(
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
)
# Create 'channel_webhook' table
op.create_table(
"channel_webhook",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"channel_id",
sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
pass
def downgrade() -> None:
# Downgrade 'channel' table
op.drop_column("channel", "is_private")
op.drop_column("channel", "archived_at")
op.drop_column("channel", "archived_by")
op.drop_column("channel", "deleted_at")
op.drop_column("channel", "deleted_by")
op.drop_column("channel", "updated_by")
# Downgrade 'channel_member' table
op.drop_column("channel_member", "role")
op.drop_column("channel_member", "invited_by")
op.drop_column("channel_member", "invited_at")
# Drop 'channel_webhook' table
op.drop_table("channel_webhook")
pass

View file

@@ -1,34 +0,0 @@
"""Add reply_to_id column to message
Revision ID: a5c220713937
Revises: 38d63c18f30f
Create Date: 2025-09-27 02:24:18.058455
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a5c220713937"
down_revision: Union[str, None] = "38d63c18f30f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add 'reply_to_id' column to the 'message' table for replying to messages
op.add_column(
"message",
sa.Column("reply_to_id", sa.Text(), nullable=True),
)
pass
def downgrade() -> None:
# Remove 'reply_to_id' column from the 'message' table
op.drop_column("message", "reply_to_id")
pass

View file

@@ -1,251 +0,0 @@
"""Update user table
Revision ID: b10670c03dd5
Revises: 2f1211949ecc
Create Date: 2025-11-28 04:55:31.737538
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import json
import time
# revision identifiers, used by Alembic.
revision: str = "b10670c03dd5"
down_revision: Union[str, None] = "2f1211949ecc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
"""
SQLite requires manual removal of any indexes referencing a column
before ALTER TABLE ... DROP COLUMN can succeed.
"""
indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall()
for idx in indexes:
index_name = idx[1] # index name
# Get indexed columns
idx_info = conn.execute(
sa.text(f"PRAGMA index_info('{index_name}')")
).fetchall()
indexed_cols = [row[2] for row in idx_info] # col names
if column_name in indexed_cols:
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
def _convert_column_to_json(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
# SQLite cannot ALTER COLUMN → must recreate column
if dialect == "sqlite":
# 1. Add temporary column
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
# 2. Load old data
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for row in rows:
uid, raw = row
if raw is None:
parsed = None
else:
try:
parsed = json.loads(raw)
except Exception:
parsed = None # fallback safe behavior
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
{"val": json.dumps(parsed) if parsed else None, "id": uid},
)
# 3. Drop old TEXT column
op.drop_column(table, column)
# 4. Rename new JSON column → original name
op.alter_column(table, f"{column}_json", new_column_name=column)
else:
# PostgreSQL supports direct CAST
op.alter_column(
table,
column,
type_=sa.JSON(),
postgresql_using=f"{column}::json",
)
def _convert_column_to_text(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
if dialect == "sqlite":
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for uid, raw in rows:
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
{"val": json.dumps(raw) if raw else None, "id": uid},
)
op.drop_column(table, column)
op.alter_column(table, f"{column}_text", new_column_name=column)
else:
op.alter_column(
table,
column,
type_=sa.Text(),
postgresql_using=f"to_json({column})::text",
)
def upgrade() -> None:
op.add_column(
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
)
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
op.add_column(
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
)
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
# Convert info (TEXT/JSONField) → JSON
_convert_column_to_json("user", "info")
# Convert settings (TEXT/JSONField) → JSON
_convert_column_to_json("user", "settings")
op.create_table(
"api_key",
sa.Column("id", sa.Text(), primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
sa.Column("key", sa.Text(), unique=True, nullable=False),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("expires_at", sa.BigInteger(), nullable=True),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
).fetchall()
for uid, oauth_sub in users:
if oauth_sub:
# Example formats supported:
# provider@sub
# plain sub (stored as {"oidc": {"sub": sub}})
if "@" in oauth_sub:
provider, sub = oauth_sub.split("@", 1)
else:
provider, sub = "oidc", oauth_sub
oauth_json = json.dumps({provider: {"sub": sub}})
conn.execute(
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
{"oauth": oauth_json, "id": uid},
)
users_with_keys = conn.execute(
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
).fetchall()
now = int(time.time())
for uid, api_key in users_with_keys:
if api_key:
conn.execute(
sa.text(
"""
INSERT INTO api_key (id, user_id, key, created_at, updated_at)
VALUES (:id, :user_id, :key, :created_at, :updated_at)
"""
),
{
"id": f"key_{uid}",
"user_id": uid,
"key": api_key,
"created_at": now,
"updated_at": now,
},
)
if conn.dialect.name == "sqlite":
_drop_sqlite_indexes_for_column("user", "api_key", conn)
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("api_key")
batch_op.drop_column("oauth_sub")
def downgrade() -> None:
# --- 1. Restore old oauth_sub column ---
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
).fetchall()
for uid, oauth in users:
try:
data = json.loads(oauth)
provider = list(data.keys())[0]
sub = data[provider].get("sub")
oauth_sub = f"{provider}@{sub}"
except Exception:
oauth_sub = None
conn.execute(
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
{"oauth_sub": oauth_sub, "id": uid},
)
op.drop_column("user", "oauth")
# --- 2. Restore api_key field ---
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
# Restore values from api_key
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
for uid, key in keys:
conn.execute(
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
{"key": key, "id": uid},
)
# Drop new table
op.drop_table("api_key")
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("profile_banner_image_url")
batch_op.drop_column("timezone")
batch_op.drop_column("presence_state")
batch_op.drop_column("status_emoji")
batch_op.drop_column("status_message")
batch_op.drop_column("status_expires_at")
# Convert info (JSON) → TEXT
_convert_column_to_text("user", "info")
# Convert settings (JSON) → TEXT
_convert_column_to_text("user", "settings")

View file

@@ -1,57 +0,0 @@
"""Add chat_file table
Revision ID: c440947495f3
Revises: 81cc2ce44d79
Create Date: 2025-12-21 20:27:41.694897
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c440947495f3"
down_revision: Union[str, None] = "81cc2ce44d79"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"chat_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"chat_id",
sa.Text(),
sa.ForeignKey("chat.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"file_id",
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("message_id", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_chat_file_chat_id", "chat_id"),
sa.Index("ix_chat_file_file_id", "file_id"),
sa.Index("ix_chat_file_message_id", "message_id"),
sa.Index("ix_chat_file_user_id", "user_id"),
# unique constraints
sa.UniqueConstraint(
"chat_id", "file_id", name="uq_chat_file_chat_file"
), # prevent duplicate entries
)
pass
def downgrade() -> None:
op.drop_table("chat_file")
pass

View file

@@ -1,23 +0,0 @@
"""Update folder table data
Revision ID: d31026856c01
Revises: 9f0c9cd09105
Create Date: 2025-07-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "d31026856c01"
down_revision = "9f0c9cd09105"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
def downgrade():
op.drop_column("folder", "data")

View file

@@ -3,11 +3,14 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, UserProfileImageResponse, Users from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text from sqlalchemy import Boolean, Column, String, Text
from open_webui.utils.auth import verify_password
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# DB MODEL # DB MODEL
@@ -17,7 +20,7 @@ log = logging.getLogger(__name__)
class Auth(Base): class Auth(Base):
__tablename__ = "auth" __tablename__ = "auth"
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True)
email = Column(String) email = Column(String)
password = Column(Text) password = Column(Text)
active = Column(Boolean) active = Column(Boolean)
@@ -44,7 +47,15 @@ class ApiKey(BaseModel):
api_key: Optional[str] = None api_key: Optional[str] = None
class SigninResponse(Token, UserProfileImageResponse): class UserResponse(BaseModel):
id: str
email: str
name: str
role: str
profile_image_url: str
class SigninResponse(Token, UserResponse):
pass pass
@@ -62,6 +73,11 @@ class ProfileImageUrlForm(BaseModel):
profile_image_url: str profile_image_url: str
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
class UpdatePasswordForm(BaseModel): class UpdatePasswordForm(BaseModel):
password: str password: str
new_password: str new_password: str
@@ -86,7 +102,7 @@ class AuthsTable:
name: str, name: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth: Optional[dict] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
@@ -100,7 +116,7 @@ class AuthsTable:
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth=oauth id, name, email, profile_image_url, role, oauth_sub
) )
db.commit() db.commit()
@@ -111,9 +127,7 @@ class AuthsTable:
else: else:
return None return None
def authenticate_user( def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
self, email: str, verify_password: callable
) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email)
@@ -124,7 +138,7 @@ class AuthsTable:
with get_db() as db: with get_db() as db:
auth = db.query(Auth).filter_by(id=user.id, active=True).first() auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth: if auth:
if verify_password(auth.password): if verify_password(password, auth.password):
return user return user
else: else:
return None return None
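
For readers tracking the authenticate_user change above: on the main side it takes a verification callable instead of the raw password. A hedged usage sketch; the instantiation and example values are ours, and the verify_password import path is the one visible on the v0.6.12 side of this diff.

from open_webui.models.auths import AuthsTable
from open_webui.utils.auth import verify_password

auths = AuthsTable()  # the module likely exposes a ready-made singleton instead
plain_password = "correct horse battery staple"  # whatever the user typed

user = auths.authenticate_user(
    "user@example.com",
    lambda hashed: verify_password(plain_password, hashed),
)
print("signed in" if user else "invalid credentials or inactive account")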

View file

@@ -4,24 +4,10 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import (
BigInteger,
Boolean,
Column,
ForeignKey,
String,
Text,
JSON,
UniqueConstraint,
case,
cast,
)
from sqlalchemy import or_, func, select, and_, text from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
@@ -33,30 +19,19 @@ from sqlalchemy.sql import exists
class Channel(Base): class Channel(Base):
__tablename__ = "channel" __tablename__ = "channel"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
type = Column(Text, nullable=True) type = Column(Text, nullable=True)
name = Column(Text) name = Column(Text)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# Used to indicate if the channel is private (for 'group' type channels)
is_private = Column(Boolean, nullable=True)
data = Column(JSON, nullable=True) data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True) access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
updated_by = Column(Text, nullable=True)
archived_at = Column(BigInteger, nullable=True)
archived_by = Column(Text, nullable=True)
deleted_at = Column(BigInteger, nullable=True)
deleted_by = Column(Text, nullable=True)
class ChannelModel(BaseModel): class ChannelModel(BaseModel):
@@ -64,157 +39,17 @@ class ChannelModel(BaseModel):
id: str id: str
user_id: str user_id: str
type: Optional[str] = None type: Optional[str] = None
name: str name: str
description: Optional[str] = None description: Optional[str] = None
is_private: Optional[bool] = None
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
access_control: Optional[dict] = None access_control: Optional[dict] = None
created_at: int # timestamp in epoch (time_ns) created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
updated_at: int # timestamp in epoch (time_ns)
updated_by: Optional[str] = None
archived_at: Optional[int] = None # timestamp in epoch (time_ns)
archived_by: Optional[str] = None
deleted_at: Optional[int] = None # timestamp in epoch (time_ns)
deleted_by: Optional[str] = None
class ChannelMember(Base):
__tablename__ = "channel_member"
id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False)
user_id = Column(Text, nullable=False)
role = Column(Text, nullable=True)
status = Column(Text, nullable=True)
is_active = Column(Boolean, nullable=False, default=True)
is_channel_muted = Column(Boolean, nullable=False, default=False)
is_channel_pinned = Column(Boolean, nullable=False, default=False)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
invited_at = Column(BigInteger, nullable=True)
invited_by = Column(Text, nullable=True)
joined_at = Column(BigInteger)
left_at = Column(BigInteger, nullable=True)
last_read_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class ChannelMemberModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
channel_id: str
user_id: str
role: Optional[str] = None
status: Optional[str] = None
is_active: bool = True
is_channel_muted: bool = False
is_channel_pinned: bool = False
data: Optional[dict] = None
meta: Optional[dict] = None
invited_at: Optional[int] = None # timestamp in epoch (time_ns)
invited_by: Optional[str] = None
joined_at: Optional[int] = None # timestamp in epoch (time_ns)
left_at: Optional[int] = None # timestamp in epoch (time_ns)
last_read_at: Optional[int] = None # timestamp in epoch (time_ns)
created_at: Optional[int] = None # timestamp in epoch (time_ns)
updated_at: Optional[int] = None # timestamp in epoch (time_ns)
class ChannelFile(Base):
__tablename__ = "channel_file"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text, nullable=False)
channel_id = Column(
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
)
message_id = Column(
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
)
class ChannelFileModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
channel_id: str
file_id: str
user_id: str
created_at: int # timestamp in epoch (time_ns)
updated_at: int # timestamp in epoch (time_ns)
class ChannelWebhook(Base):
__tablename__ = "channel_webhook"
id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False)
user_id = Column(Text, nullable=False)
name = Column(Text, nullable=False)
profile_image_url = Column(Text, nullable=True)
token = Column(Text, nullable=False)
last_used_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class ChannelWebhookModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
channel_id: str
user_id: str
name: str
profile_image_url: Optional[str] = None
token: str
last_used_at: Optional[int] = None # timestamp in epoch (time_ns)
created_at: int # timestamp in epoch (time_ns)
updated_at: int # timestamp in epoch (time_ns)
#################### ####################
@@ -222,95 +57,23 @@ class ChannelWebhookModel(BaseModel):
#################### ####################
class ChannelResponse(ChannelModel):
is_manager: bool = False
write_access: bool = False
user_count: Optional[int] = None
class ChannelForm(BaseModel): class ChannelForm(BaseModel):
name: str = "" name: str
description: Optional[str] = None description: Optional[str] = None
is_private: Optional[bool] = None
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
access_control: Optional[dict] = None access_control: Optional[dict] = None
group_ids: Optional[list[str]] = None
user_ids: Optional[list[str]] = None
class CreateChannelForm(ChannelForm):
type: Optional[str] = None
class ChannelTable: class ChannelTable:
def _collect_unique_user_ids(
self,
invited_by: str,
user_ids: Optional[list[str]] = None,
group_ids: Optional[list[str]] = None,
) -> set[str]:
"""
Collect unique user ids from:
- invited_by
- user_ids
- each group in group_ids
Returns a set for efficient SQL diffing.
"""
users = set(user_ids or [])
users.add(invited_by)
for group_id in group_ids or []:
users.update(Groups.get_group_user_ids_by_id(group_id))
return users
def _create_membership_models(
self,
channel_id: str,
invited_by: str,
user_ids: set[str],
) -> list[ChannelMember]:
"""
Takes a set of NEW user IDs (already filtered to exclude existing members).
Returns ORM ChannelMember objects to be added.
"""
now = int(time.time_ns())
memberships = []
for uid in user_ids:
model = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"user_id": uid,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"invited_at": now,
"invited_by": invited_by,
"joined_at": now,
"left_at": None,
"last_read_at": now,
"created_at": now,
"updated_at": now,
}
)
memberships.append(ChannelMember(**model.model_dump()))
return memberships
def insert_new_channel( def insert_new_channel(
self, form_data: CreateChannelForm, user_id: str self, type: Optional[str], form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]: ) -> Optional[ChannelModel]:
with get_db() as db: with get_db() as db:
channel = ChannelModel( channel = ChannelModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
"type": form_data.type if form_data.type else None, "type": type,
"name": form_data.name.lower(), "name": form_data.name.lower(),
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
"user_id": user_id, "user_id": user_id,
@@ -318,21 +81,9 @@ class ChannelTable:
"updated_at": int(time.time_ns()), "updated_at": int(time.time_ns()),
} }
) )
new_channel = Channel(**channel.model_dump()) new_channel = Channel(**channel.model_dump())
if form_data.type in ["group", "dm"]:
users = self._collect_unique_user_ids(
invited_by=user_id,
user_ids=form_data.user_ids,
group_ids=form_data.group_ids,
)
memberships = self._create_membership_models(
channel_id=new_channel.id,
invited_by=user_id,
user_ids=users,
)
db.add_all(memberships)
db.add(new_channel) db.add(new_channel)
db.commit() db.commit()
return channel return channel
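
A usage sketch of the main-side insert_new_channel signature above: for "group" and "dm" channels, membership rows are created for the inviter plus the listed users and group members. The instantiation and ids below are placeholders.

from open_webui.models.channels import ChannelTable, CreateChannelForm

channels = ChannelTable()  # the module likely exposes a ready-made singleton instead

form = CreateChannelForm(
    type="group",
    name="Project-X",  # stored lowercased as "project-x"
    description="Private working group",
    is_private=True,
    user_ids=["user-2", "user-3"],
)
# Creator "user-1" and user-2/user-3 each get a ChannelMember row with status="joined".
channel = channels.insert_new_channel(form, "user-1")
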
@@ -342,481 +93,22 @@ class ChannelTable:
channels = db.query(Channel).all() channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels] return [ChannelModel.model_validate(channel) for channel in channels]
def _has_permission(self, db, query, filter: dict, permission: str = "read"): def get_channels_by_user_id(
group_ids = filter.get("group_ids", []) self, user_id: str, permission: str = "read"
user_id = filter.get("user_id") ) -> list[ChannelModel]:
channels = self.get_channels()
dialect_name = db.bind.dialect.name return [
channel
# Public access for channel in channels
conditions = [] if channel.user_id == user_id
if group_ids or user_id: or has_access(user_id, permission, channel.access_control)
conditions.extend( ]
[
Channel.access_control.is_(None),
cast(Channel.access_control, String) == "null",
]
)
# User-level permission
if user_id:
conditions.append(Channel.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
Channel.access_control[permission]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Channel.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query
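
For orientation, a sketch of the access_control shape this filter expects: per-permission buckets keyed by "read"/"write". Values are placeholders; the user_ids bucket is consumed by other access helpers rather than this method, and None (or a JSON null) means the channel is treated as public here.

access_control = {
    "read": {"group_ids": ["group-abc"], "user_ids": []},
    "write": {"group_ids": [], "user_ids": ["user-123"]},
}
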
def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]:
with get_db() as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
]
membership_channels = (
db.query(Channel)
.join(ChannelMember, Channel.id == ChannelMember.channel_id)
.filter(
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
Channel.type.in_(["group", "dm"]),
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
.all()
)
query = db.query(Channel).filter(
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
or_(
Channel.type.is_(None), # True NULL/None
Channel.type == "", # Empty string
and_(Channel.type != "group", Channel.type != "dm"),
),
)
query = self._has_permission(
db, query, {"user_id": user_id, "group_ids": user_group_ids}
)
standard_channels = query.all()
all_channels = membership_channels + standard_channels
return [ChannelModel.model_validate(c) for c in all_channels]
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
with get_db() as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
match_count = func.sum(
case(
(ChannelMember.user_id.in_(unique_user_ids), 1),
else_=0,
)
)
subquery = (
db.query(ChannelMember.channel_id)
.group_by(ChannelMember.channel_id)
# 1. Channel must have exactly len(user_ids) members
.having(func.count(ChannelMember.user_id) == len(unique_user_ids))
# 2. All those members must be in unique_user_ids
.having(match_count == len(unique_user_ids))
.subquery()
)
channel = (
db.query(Channel)
.filter(
Channel.id.in_(subquery),
Channel.type == "dm",
)
.first()
)
return ChannelModel.model_validate(channel) if channel else None
def add_members_to_channel(
self,
channel_id: str,
invited_by: str,
user_ids: Optional[list[str]] = None,
group_ids: Optional[list[str]] = None,
) -> list[ChannelMemberModel]:
with get_db() as db:
# 1. Collect all user_ids including groups + inviter
requested_users = self._collect_unique_user_ids(
invited_by, user_ids, group_ids
)
existing_users = {
row.user_id
for row in db.query(ChannelMember.user_id)
.filter(ChannelMember.channel_id == channel_id)
.all()
}
new_user_ids = requested_users - existing_users
if not new_user_ids:
return [] # Nothing to add
new_memberships = self._create_membership_models(
channel_id, invited_by, new_user_ids
)
db.add_all(new_memberships)
db.commit()
return [
ChannelMemberModel.model_validate(membership)
for membership in new_memberships
]
def remove_members_from_channel(
self,
channel_id: str,
user_ids: list[str],
) -> int:
with get_db() as db:
result = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id.in_(user_ids),
)
.delete(synchronize_session=False)
)
db.commit()
return result # number of rows deleted
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
# Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember
channel = db.query(Channel).filter(Channel.id == channel_id).first()
if channel and channel.user_id == user_id:
return True
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
ChannelMember.role == "manager",
)
.first()
)
return membership is not None
def join_channel(
self, channel_id: str, user_id: str
) -> Optional[ChannelMemberModel]:
with get_db() as db:
# Check if the membership already exists
existing_membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if existing_membership:
return ChannelMemberModel.model_validate(existing_membership)
# Create new membership
channel_member = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"user_id": user_id,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"joined_at": int(time.time_ns()),
"left_at": None,
"last_read_at": int(time.time_ns()),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_membership = ChannelMember(**channel_member.model_dump())
db.add(new_membership)
db.commit()
return channel_member
def leave_channel(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.status = "left"
membership.is_active = False
membership.left_at = int(time.time_ns())
membership.updated_at = int(time.time_ns())
db.commit()
return True
def get_member_by_channel_and_user_id(
self, channel_id: str, user_id: str
) -> Optional[ChannelMemberModel]:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
with get_db() as db:
memberships = (
db.query(ChannelMember)
.filter(ChannelMember.channel_id == channel_id)
.all()
)
return [
ChannelMemberModel.model_validate(membership)
for membership in memberships
]
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.is_channel_pinned = is_pinned
membership.updated_at = int(time.time_ns())
db.commit()
return True
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.last_read_at = int(time.time_ns())
membership.updated_at = int(time.time_ns())
db.commit()
return True
def update_member_active_status(
self, channel_id: str, user_id: str, is_active: bool
) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.is_active = is_active
membership.updated_at = int(time.time_ns())
db.commit()
return True
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
return membership is not None
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db: with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first() channel = db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None return ChannelModel.model_validate(channel) if channel else None
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
with get_db() as db:
channel_files = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_ids = [cf.channel_id for cf in channel_files]
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_file_id_and_user_id(
self, file_id: str, user_id: str
) -> list[ChannelModel]:
with get_db() as db:
# 1. Determine which channels have this file
channel_file_rows = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_ids = [row.channel_id for row in channel_file_rows]
if not channel_ids:
return []
# 2. Load all channel rows that still exist
channels = (
db.query(Channel)
.filter(
Channel.id.in_(channel_ids),
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
)
.all()
)
if not channels:
return []
# Preload user's group membership
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)]
allowed_channels = []
for channel in channels:
# --- Case A: group or dm => user must be an active member ---
if channel.type in ["group", "dm"]:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel.id,
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
.first()
)
if membership:
allowed_channels.append(ChannelModel.model_validate(channel))
continue
# --- Case B: standard channel => rely on ACL permissions ---
query = db.query(Channel).filter(Channel.id == channel.id)
query = self._has_permission(
db,
query,
{"user_id": user_id, "group_ids": user_group_ids},
permission="read",
)
allowed = query.first()
if allowed:
allowed_channels.append(ChannelModel.model_validate(allowed))
return allowed_channels
def get_channel_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
# Fetch the channel
channel: Channel = (
db.query(Channel)
.filter(
Channel.id == id,
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
)
.first()
)
if not channel:
return None
# If the channel is a group or dm, read access requires membership (active)
if channel.type in ["group", "dm"]:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == id,
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
.first()
)
if membership:
return ChannelModel.model_validate(channel)
else:
return None
# For channels that are NOT group/dm, fall back to ACL-based read access
query = db.query(Channel).filter(Channel.id == id)
# Determine user groups
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
]
# Apply ACL rules
query = self._has_permission(
db,
query,
{"user_id": user_id, "group_ids": user_group_ids},
permission="read",
)
channel_allowed = query.first()
return (
ChannelModel.model_validate(channel_allowed)
if channel_allowed
else None
)
def update_channel_by_id( def update_channel_by_id(
self, id: str, form_data: ChannelForm self, id: str, form_data: ChannelForm
) -> Optional[ChannelModel]: ) -> Optional[ChannelModel]:
@@ -826,77 +118,14 @@ class ChannelTable:
return None return None
channel.name = form_data.name channel.name = form_data.name
channel.description = form_data.description
channel.is_private = form_data.is_private
channel.data = form_data.data channel.data = form_data.data
channel.meta = form_data.meta channel.meta = form_data.meta
channel.access_control = form_data.access_control channel.access_control = form_data.access_control
channel.updated_at = int(time.time_ns()) channel.updated_at = int(time.time_ns())
db.commit() db.commit()
return ChannelModel.model_validate(channel) if channel else None return ChannelModel.model_validate(channel) if channel else None
def add_file_to_channel_by_id(
self, channel_id: str, file_id: str, user_id: str
) -> Optional[ChannelFileModel]:
with get_db() as db:
channel_file = ChannelFileModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"file_id": file_id,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = ChannelFile(**channel_file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChannelFileModel.model_validate(result)
else:
return None
except Exception:
return None
def set_file_message_id_in_channel_by_id(
self, channel_id: str, file_id: str, message_id: str
) -> bool:
try:
with get_db() as db:
channel_file = (
db.query(ChannelFile)
.filter_by(channel_id=channel_id, file_id=file_id)
.first()
)
if not channel_file:
return False
channel_file.message_id = message_id
channel_file.updated_at = int(time.time())
db.commit()
return True
except Exception:
return False
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool:
try:
with get_db() as db:
db.query(ChannelFile).filter_by(
channel_id=channel_id, file_id=file_id
).delete()
db.commit()
return True
except Exception:
return False
def delete_channel_by_id(self, id: str): def delete_channel_by_id(self, id: str):
with get_db() as db: with get_db() as db:
db.query(Channel).filter(Channel.id == id).delete() db.query(Channel).filter(Channel.id == id).delete()
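
The DM lookup in get_dm_channel_by_user_ids above hinges on a group/count/match subquery: a channel qualifies when it has exactly len(user_ids) members and every one of them is in the requested set. An equivalent raw-SQL sketch (engine URL and ids are placeholders):

import sqlalchemy as sa

user_ids = ["user-a", "user-b"]

dm_lookup = sa.text(
    """
    SELECT c.*
    FROM channel c
    WHERE c.type = 'dm'
      AND c.id IN (
          SELECT cm.channel_id
          FROM channel_member cm
          GROUP BY cm.channel_id
          HAVING COUNT(cm.user_id) = :n
             AND SUM(CASE WHEN cm.user_id IN :uids THEN 1 ELSE 0 END) = :n
      )
    """
).bindparams(sa.bindparam("uids", expanding=True))

engine = sa.create_engine("sqlite:///webui.db")  # placeholder database URL
with engine.connect() as conn:
    dm_channel = conn.execute(dm_lookup, {"n": len(user_ids), "uids": user_ids}).first()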

View file

@@ -6,36 +6,25 @@ from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.folders import Folders from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import ( from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
BigInteger,
Boolean,
Column,
ForeignKey,
String,
Text,
JSON,
Index,
UniqueConstraint,
)
from sqlalchemy import or_, func, select, and_, text from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from sqlalchemy.sql.expression import bindparam
#################### ####################
# Chat DB Schema # Chat DB Schema
#################### ####################
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base): class Chat(Base):
__tablename__ = "chat" __tablename__ = "chat"
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True)
user_id = Column(String) user_id = Column(String)
title = Column(Text) title = Column(Text)
chat = Column(JSON) chat = Column(JSON)
@@ -50,20 +39,6 @@ class Chat(Base):
meta = Column(JSON, server_default="{}") meta = Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True) folder_id = Column(Text, nullable=True)
__table_args__ = (
# Performance indexes for common queries
# WHERE folder_id = ...
Index("folder_id_idx", "folder_id"),
# WHERE user_id = ... AND pinned = ...
Index("user_id_pinned_idx", "user_id", "pinned"),
# WHERE user_id = ... AND archived = ...
Index("user_id_archived_idx", "user_id", "archived"),
# WHERE user_id = ... ORDER BY updated_at DESC
Index("updated_at_user_id_idx", "updated_at", "user_id"),
# WHERE folder_id = ... AND user_id = ...
Index("folder_id_user_id_idx", "folder_id", "user_id"),
)
class ChatModel(BaseModel): class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@@ -84,38 +59,6 @@ class ChatModel(BaseModel):
folder_id: Optional[str] = None folder_id: Optional[str] = None
class ChatFile(Base):
__tablename__ = "chat_file"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text, nullable=False)
chat_id = Column(Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False)
message_id = Column(Text, nullable=True)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint("chat_id", "file_id", name="uq_chat_file_chat_file"),
)
class ChatFileModel(BaseModel):
id: str
user_id: str
chat_id: str
message_id: Optional[str] = None
file_id: str
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
#################### ####################
@@ -123,18 +66,12 @@ class ChatFileModel(BaseModel):
class ChatForm(BaseModel): class ChatForm(BaseModel):
chat: dict chat: dict
folder_id: Optional[str] = None
class ChatImportForm(ChatForm): class ChatImportForm(ChatForm):
meta: Optional[dict] = {} meta: Optional[dict] = {}
pinned: Optional[bool] = False pinned: Optional[bool] = False
created_at: Optional[int] = None folder_id: Optional[str] = None
updated_at: Optional[int] = None
class ChatsImportForm(BaseModel):
chats: list[ChatImportForm]
class ChatTitleMessagesForm(BaseModel): class ChatTitleMessagesForm(BaseModel):
@@ -167,77 +104,7 @@ class ChatTitleIdResponse(BaseModel):
created_at: int created_at: int
class ChatListResponse(BaseModel):
items: list[ChatModel]
total: int
class ChatUsageStatsResponse(BaseModel):
id: str # chat id
models: dict = {} # models used in the chat with their usage counts
message_count: int # number of messages in the chat
history_models: dict = {} # models used in the chat history with their usage counts
history_message_count: int # number of messages in the chat history
history_user_message_count: int # number of user messages in the chat history
history_assistant_message_count: (
int # number of assistant messages in the chat history
)
average_response_time: (
float # average response time of assistant messages in seconds
)
average_user_message_content_length: (
float # average length of user message contents
)
average_assistant_message_content_length: (
float # average length of assistant message contents
)
tags: list[str] = [] # tags associated with the chat
last_message_at: int # timestamp of the last message
updated_at: int
created_at: int
model_config = ConfigDict(extra="allow")
class ChatUsageStatsListResponse(BaseModel):
items: list[ChatUsageStatsResponse]
total: int
model_config = ConfigDict(extra="allow")
class ChatTable: class ChatTable:
def _clean_null_bytes(self, obj):
"""Recursively remove null bytes from strings in dict/list structures."""
return sanitize_data_for_db(obj)
def _sanitize_chat_row(self, chat_item):
"""
Clean a Chat SQLAlchemy model's title + chat JSON,
and return True if anything changed.
"""
changed = False
# Clean title
if chat_item.title:
cleaned = self._clean_null_bytes(chat_item.title)
if cleaned != chat_item.title:
chat_item.title = cleaned
changed = True
# Clean JSON
if chat_item.chat:
cleaned = self._clean_null_bytes(chat_item.chat)
if cleaned != chat_item.chat:
chat_item.chat = cleaned
changed = True
return changed
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@@ -245,76 +112,59 @@ class ChatTable:
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": self._clean_null_bytes( "title": (
form_data.chat["title"] form_data.chat["title"]
if "title" in form_data.chat if "title" in form_data.chat
else "New Chat" else "New Chat"
), ),
"chat": self._clean_null_bytes(form_data.chat), "chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def import_chat(
self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id, "folder_id": form_data.folder_id,
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
chat_item = Chat(**chat.model_dump()) result = Chat(**chat.model_dump())
db.add(chat_item) db.add(result)
db.commit() db.commit()
db.refresh(chat_item) db.refresh(result)
return ChatModel.model_validate(chat_item) if chat_item else None return ChatModel.model_validate(result) if result else None
def _chat_import_form_to_chat_model(
self, user_id: str, form_data: ChatImportForm
) -> ChatModel:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": self._clean_null_bytes(
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
),
"chat": self._clean_null_bytes(form_data.chat),
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": (
form_data.created_at if form_data.created_at else int(time.time())
),
"updated_at": (
form_data.updated_at if form_data.updated_at else int(time.time())
),
}
)
return chat
def import_chats(
self, user_id: str, chat_import_forms: list[ChatImportForm]
) -> list[ChatModel]:
with get_db() as db:
chats = []
for form_data in chat_import_forms:
chat = self._chat_import_form_to_chat_model(user_id, form_data)
chats.append(Chat(**chat.model_dump()))
db.add_all(chats)
db.commit()
return [ChatModel.model_validate(chat) for chat in chats]
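
A short usage sketch of the bulk import path above; the Chats singleton name and all values are assumptions for illustration.

from open_webui.models.chats import ChatImportForm, Chats  # Chats = ChatTable() singleton assumed

forms = [
    ChatImportForm(
        chat={"title": "Imported chat", "history": {"messages": {}}},
        meta={"source": "export.json"},
        pinned=False,
        created_at=1700000000,
        updated_at=1700000000,
    )
]
imported = Chats.import_chats("user-1", forms)  # returns the created ChatModel rows
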
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat_item = db.get(Chat, id) chat_item = db.get(Chat, id)
chat_item.chat = self._clean_null_bytes(chat) chat_item.chat = chat
chat_item.title = ( chat_item.title = chat["title"] if "title" in chat else "New Chat"
self._clean_null_bytes(chat["title"])
if "title" in chat
else "New Chat"
)
chat_item.updated_at = int(time.time()) chat_item.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(chat_item) db.refresh(chat_item)
@@ -359,7 +209,7 @@ class ChatTable:
return chat.chat.get("title", "New Chat") return chat.chat.get("title", "New Chat")
def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]: def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id) chat = self.get_chat_by_id(id)
if chat is None: if chat is None:
return None return None
@@ -382,10 +232,6 @@ class ChatTable:
if chat is None: if chat is None:
return None return None
# Sanitize message content for null characters before upserting
if isinstance(message.get("content"), str):
message["content"] = sanitize_text_for_db(message["content"])
chat = chat.chat chat = chat.chat
history = chat.get("history", {}) history = chat.get("history", {})
@@ -420,27 +266,6 @@ class ChatTable:
chat["history"] = history chat["history"] = history
return self.update_chat_by_id(id, chat) return self.update_chat_by_id(id, chat)
def add_message_files_by_id_and_message_id(
self, id: str, message_id: str, files: list[dict]
) -> list[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
message_files = []
if message_id in history.get("messages", {}):
message_files = history["messages"][message_id].get("files", [])
message_files = message_files + files
history["messages"][message_id]["files"] = message_files
chat["history"] = history
self.update_chat_by_id(id, chat)
return message_files
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db: with get_db() as db:
# Get the existing chat to share # Get the existing chat to share
@@ -455,9 +280,6 @@ class ChatTable:
"user_id": f"shared-{chat_id}", "user_id": f"shared-{chat_id}",
"title": chat.title, "title": chat.title,
"chat": chat.chat, "chat": chat.chat,
"meta": chat.meta,
"pinned": chat.pinned,
"folder_id": chat.folder_id,
"created_at": chat.created_at, "created_at": chat.created_at,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
@@ -489,9 +311,7 @@ class ChatTable:
shared_chat.title = chat.title shared_chat.title = chat.title
shared_chat.chat = chat.chat shared_chat.chat = chat.chat
shared_chat.meta = chat.meta
shared_chat.pinned = chat.pinned
shared_chat.folder_id = chat.folder_id
shared_chat.updated_at = int(time.time()) shared_chat.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(shared_chat) db.refresh(shared_chat)
@@ -510,15 +330,6 @@ class ChatTable:
except Exception: except Exception:
return False return False
def unarchive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": False})
db.commit()
return True
except Exception:
return False
def update_chat_share_id_by_id( def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
@@ -549,7 +360,6 @@ class ChatTable:
with get_db() as db: with get_db() as db:
chat = db.get(Chat, id) chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
chat.folder_id = None
chat.updated_at = int(time.time()) chat.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(chat) db.refresh(chat)
@@ -585,10 +395,7 @@ class ChatTable:
order_by = filter.get("order_by") order_by = filter.get("order_by")
direction = filter.get("direction") direction = filter.get("direction")
if order_by and direction: if order_by and direction and getattr(Chat, order_by):
if not getattr(Chat, order_by, None):
raise ValueError("Invalid order_by field")
if direction.lower() == "asc": if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc()) query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc": elif direction.lower() == "desc":
@@ -649,19 +456,12 @@ class ChatTable:
self, self,
user_id: str, user_id: str,
include_archived: bool = False, include_archived: bool = False,
include_folders: bool = False,
include_pinned: bool = False,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]: ) -> list[ChatTitleIdResponse]:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_folders:
query = query.filter_by(folder_id=None)
if not include_pinned:
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
@ -706,15 +506,8 @@ class ChatTable:
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db: with get_db() as db:
chat_item = db.get(Chat, id) chat = db.get(Chat, id)
if chat_item is None: return ChatModel.model_validate(chat)
return None
if self._sanitize_chat_row(chat_item):
db.commit()
db.refresh(chat_item)
return ChatModel.model_validate(chat_item)
except Exception: except Exception:
return None return None
@ -749,31 +542,14 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id( def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
self, user_id: str, skip: Optional[int] = None, limit: Optional[int] = None
) -> ChatListResponse:
with get_db() as db: with get_db() as db:
query = ( all_chats = (
db.query(Chat) db.query(Chat)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats]
total = query.count()
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
all_chats = query.all()
return ChatListResponse(
**{
"items": [ChatModel.model_validate(chat) for chat in all_chats],
"total": total,
}
)
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
@ -804,7 +580,7 @@ class ChatTable:
""" """
Filters chats based on a search query using Python, allowing pagination using skip and limit. Filters chats based on a search query using Python, allowing pagination using skip and limit.
""" """
search_text = sanitize_text_for_db(search_text).lower().strip() search_text = search_text.lower().strip()
if not search_text: if not search_text:
return self.get_chat_list_by_user_id( return self.get_chat_list_by_user_id(
@ -820,45 +596,8 @@ class ChatTable:
if word.startswith("tag:") if word.startswith("tag:")
] ]
# Extract folder names - handle spaces and case insensitivity
folders = Folders.search_folders_by_names(
user_id,
[
word.replace("folder:", "")
for word in search_text_words
if word.startswith("folder:")
],
)
folder_ids = [folder.id for folder in folders]
is_pinned = None
if "pinned:true" in search_text_words:
is_pinned = True
elif "pinned:false" in search_text_words:
is_pinned = False
is_archived = None
if "archived:true" in search_text_words:
is_archived = True
elif "archived:false" in search_text_words:
is_archived = False
is_shared = None
if "shared:true" in search_text_words:
is_shared = True
elif "shared:false" in search_text_words:
is_shared = False
search_text_words = [ search_text_words = [
word word for word in search_text_words if not word.startswith("tag:")
for word in search_text_words
if (
not word.startswith("tag:")
and not word.startswith("folder:")
and not word.startswith("pinned:")
and not word.startswith("archived:")
and not word.startswith("shared:")
)
] ]
search_text = " ".join(search_text_words) search_text = " ".join(search_text_words)
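For reference, a minimal runnable sketch of the search-prefix parsing shown in this hunk (tag:, folder:, pinned:, archived:, shared:); the function name and return shape are illustrative, not the module's API:

from typing import Optional

def parse_search_query(search_text: str) -> dict:
    # Lower-case the raw query and split it into words.
    words = search_text.lower().strip().split()

    def flag(prefix: str) -> Optional[bool]:
        # True/False when "<prefix>:true" / "<prefix>:false" is present, else None.
        if f"{prefix}:true" in words:
            return True
        if f"{prefix}:false" in words:
            return False
        return None

    prefixes = ("tag:", "folder:", "pinned:", "archived:", "shared:")
    return {
        "tags": [w[len("tag:"):] for w in words if w.startswith("tag:")],
        "folders": [w[len("folder:"):] for w in words if w.startswith("folder:")],
        "pinned": flag("pinned"),
        "archived": flag("archived"),
        "shared": flag("shared"),
        # Whatever is left is the free-text part of the search.
        "text": " ".join(w for w in words if not w.startswith(prefixes)),
    }

# parse_search_query("report folder:Work pinned:true tag:finance")
# -> {'tags': ['finance'], 'folders': ['work'], 'pinned': True,
#     'archived': None, 'shared': None, 'text': 'report'}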
@ -866,41 +605,30 @@ class ChatTable:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id) query = db.query(Chat).filter(Chat.user_id == user_id)
if is_archived is not None: if not include_archived:
query = query.filter(Chat.archived == is_archived)
elif not include_archived:
query = query.filter(Chat.archived == False) query = query.filter(Chat.archived == False)
if is_pinned is not None:
query = query.filter(Chat.pinned == is_pinned)
if is_shared is not None:
if is_shared:
query = query.filter(Chat.share_id.isnot(None))
else:
query = query.filter(Chat.share_id.is_(None))
if folder_ids:
query = query.filter(Chat.folder_id.in_(folder_ids))
query = query.order_by(Chat.updated_at.desc()) query = query.order_by(Chat.updated_at.desc())
# Check if the database dialect is either 'sqlite' or 'postgresql' # Check if the database dialect is either 'sqlite' or 'postgresql'
dialect_name = db.bind.dialect.name dialect_name = db.bind.dialect.name
if dialect_name == "sqlite": if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching # SQLite case: using JSON1 extension for JSON searching
sqlite_content_sql = (
"EXISTS ("
" SELECT 1 "
" FROM json_each(Chat.chat, '$.messages') AS message "
" WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
")"
)
sqlite_content_clause = text(sqlite_content_sql)
query = query.filter( query = query.filter(
or_( (
Chat.title.ilike(bindparam("title_key")), sqlite_content_clause Chat.title.ilike(
).params(title_key=f"%{search_text}%", content_key=search_text) f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
) )
# Check if there are any tags to filter, it should have all the tags # Check if there are any tags to filter, it should have all the tags
@ -934,32 +662,23 @@ class ChatTable:
) )
elif dialect_name == "postgresql": elif dialect_name == "postgresql":
# PostgreSQL doesn't allow null bytes in text. We filter those out by checking # PostgreSQL relies on proper JSON query for search
# the JSON representation for \u0000 before attempting text extraction
# Safety filter: JSON field must not contain \u0000
query = query.filter(text("Chat.chat::text NOT LIKE '%\\\\u0000%'"))
# Safety filter: title must not contain actual null bytes
query = query.filter(text("Chat.title::text NOT LIKE '%\\x00%'"))
postgres_content_sql = """
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE json_typeof(message->'content') = 'string'
AND LOWER(message->>'content') LIKE '%' || :content_key || '%'
)
"""
postgres_content_clause = text(postgres_content_sql)
query = query.filter( query = query.filter(
or_( (
Chat.title.ilike(bindparam("title_key")), Chat.title.ilike(
postgres_content_clause, f"%{search_text}%"
) ) # Case-insensitive search in title
).params(title_key=f"%{search_text}%", content_key=search_text.lower()) | text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags # Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids: if "none" in tag_ids:
@ -1004,7 +723,7 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id( def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60 self, folder_id: str, user_id: str
) -> list[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
@ -1013,11 +732,6 @@ class ChatTable:
query = query.order_by(Chat.updated_at.desc()) query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all() all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
@ -1147,16 +861,6 @@ class ChatTable:
return count return count
def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
query = query.filter_by(folder_id=folder_id)
count = query.count()
log.info(f"Count of chats for folder '{folder_id}': {count}")
return count
def delete_tag_by_id_and_user_id_and_tag_name( def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str self, id: str, user_id: str, tag_name: str
) -> bool: ) -> bool:
@ -1234,20 +938,6 @@ class ChatTable:
except Exception: except Exception:
return False return False
def move_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str, new_folder_id: Optional[str]
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
{"folder_id": new_folder_id}
)
db.commit()
return True
except Exception:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
@ -1261,93 +951,5 @@ class ChatTable:
except Exception: except Exception:
return False return False
def insert_chat_files(
self, chat_id: str, message_id: str, file_ids: list[str], user_id: str
) -> Optional[list[ChatFileModel]]:
if not file_ids:
return None
chat_message_file_ids = [
item.id
for item in self.get_chat_files_by_chat_id_and_message_id(
chat_id, message_id
)
]
# Remove duplicates and existing file_ids
file_ids = list(
set(
[
file_id
for file_id in file_ids
if file_id and file_id not in chat_message_file_ids
]
)
)
if not file_ids:
return None
try:
with get_db() as db:
now = int(time.time())
chat_files = [
ChatFileModel(
id=str(uuid.uuid4()),
user_id=user_id,
chat_id=chat_id,
message_id=message_id,
file_id=file_id,
created_at=now,
updated_at=now,
)
for file_id in file_ids
]
results = [
ChatFile(**chat_file.model_dump()) for chat_file in chat_files
]
db.add_all(results)
db.commit()
return chat_files
except Exception:
return None
def get_chat_files_by_chat_id_and_message_id(
self, chat_id: str, message_id: str
) -> list[ChatFileModel]:
with get_db() as db:
all_chat_files = (
db.query(ChatFile)
.filter_by(chat_id=chat_id, message_id=message_id)
.order_by(ChatFile.created_at.asc())
.all()
)
return [
ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files
]
def delete_chat_file(self, chat_id: str, file_id: str) -> bool:
try:
with get_db() as db:
db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete()
db.commit()
return True
except Exception:
return False
def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]:
with get_db() as db:
# Join Chat and ChatFile tables to get shared chats associated with the file_id
all_chats = (
db.query(Chat)
.join(ChatFile, Chat.id == ChatFile.chat_id)
.filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None))
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
Chats = ChatTable() Chats = ChatTable()
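For reference, a minimal sketch of the bound-parameter content search used in the hunks above, assuming the SQLite dialect, SQLAlchemy 1.4+, and a cut-down Chat model (the real table has more columns); messages are assumed to live under the chat JSON's $.messages path, as in the SQL above:

from sqlalchemy import Column, JSON, Text, bindparam, or_, text
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Chat(Base):
    # Illustrative subset of the real chat table.
    __tablename__ = "chat"
    id = Column(Text, primary_key=True)
    title = Column(Text)
    chat = Column(JSON)

# SQLite JSON1 clause: true when any message content contains the bound search text.
SQLITE_CONTENT_SQL = text(
    "EXISTS ("
    " SELECT 1"
    " FROM json_each(chat.chat, '$.messages') AS message"
    " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
    ")"
)

def filter_by_search_text(query, search_text: str):
    # Bind the user-supplied text as parameters instead of formatting it into the SQL string.
    return query.filter(
        or_(Chat.title.ilike(bindparam("title_key")), SQLITE_CONTENT_SQL)
    ).params(title_key=f"%{search_text}%", content_key=search_text.lower())

Query.params() supplies values for both the bindparam on the title clause and the :content_key placeholder inside the raw SQL, so one call covers both branches of the or_().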
View file
@ -4,12 +4,14 @@ import uuid
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.users import User from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
@ -19,7 +21,7 @@ log = logging.getLogger(__name__)
class Feedback(Base): class Feedback(Base):
__tablename__ = "feedback" __tablename__ = "feedback"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
version = Column(BigInteger, default=0) version = Column(BigInteger, default=0)
type = Column(Text) type = Column(Text)
@ -60,13 +62,6 @@ class FeedbackResponse(BaseModel):
updated_at: int updated_at: int
class FeedbackIdResponse(BaseModel):
id: str
user_id: str
created_at: int
updated_at: int
class RatingData(BaseModel): class RatingData(BaseModel):
rating: Optional[str | int] = None rating: Optional[str | int] = None
model_id: Optional[str] = None model_id: Optional[str] = None
@ -97,28 +92,6 @@ class FeedbackForm(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str = "pending"
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserResponse] = None
class FeedbackListResponse(BaseModel):
items: list[FeedbackUserResponse]
total: int
class FeedbackTable: class FeedbackTable:
def insert_new_feedback( def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm self, user_id: str, form_data: FeedbackForm
@ -170,70 +143,6 @@ class FeedbackTable:
except Exception: except Exception:
return None return None
def get_feedback_items(
self, filter: dict = {}, skip: int = 0, limit: int = 30
) -> FeedbackListResponse:
with get_db() as db:
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
if filter:
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "username":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "model_id":
# it's stored in feedback.data['model_id']
if direction == "asc":
query = query.order_by(
Feedback.data["model_id"].as_string().asc()
)
else:
query = query.order_by(
Feedback.data["model_id"].as_string().desc()
)
elif order_by == "rating":
# it's stored in feedback.data['rating']
if direction == "asc":
query = query.order_by(
Feedback.data["rating"].as_string().asc()
)
else:
query = query.order_by(
Feedback.data["rating"].as_string().desc()
)
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Feedback.updated_at.asc())
else:
query = query.order_by(Feedback.updated_at.desc())
else:
query = query.order_by(Feedback.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
feedbacks = []
for feedback, user in items:
feedback_model = FeedbackModel.model_validate(feedback)
user_model = UserResponse.model_validate(user)
feedbacks.append(
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
)
return FeedbackListResponse(items=feedbacks, total=total)
def get_all_feedbacks(self) -> list[FeedbackModel]: def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db: with get_db() as db:
return [ return [
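For reference, a minimal sketch of the count-before-pagination pattern used by get_feedback_items above; it works with any SQLAlchemy Query, and the items/total shape mirrors the list responses in these models:

from typing import Optional

def paginate(query, skip: Optional[int] = None, limit: Optional[int] = None) -> dict:
    # Count the full result set before OFFSET/LIMIT so the client can render paging controls.
    total = query.count()
    if skip:
        query = query.offset(skip)
    if limit:
        query = query.limit(limit)
    return {"items": query.all(), "total": total}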
View file
@ -3,10 +3,12 @@ import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON from sqlalchemy import BigInteger, Column, String, Text, JSON
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Files DB Schema # Files DB Schema
@ -15,7 +17,7 @@ log = logging.getLogger(__name__)
class File(Base): class File(Base):
__tablename__ = "file" __tablename__ = "file"
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True)
user_id = Column(String) user_id = Column(String)
hash = Column(Text, nullable=True) hash = Column(Text, nullable=True)
@ -80,8 +82,7 @@ class FileModelResponse(BaseModel):
class FileMetadataResponse(BaseModel): class FileMetadataResponse(BaseModel):
id: str id: str
hash: Optional[str] = None meta: dict
meta: Optional[dict] = None
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
@ -96,17 +97,6 @@ class FileForm(BaseModel):
access_control: Optional[dict] = None access_control: Optional[dict] = None
class FileUpdateForm(BaseModel):
hash: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
class FileListResponse(BaseModel):
items: list[FileModel]
total: int
class FilesTable: class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
@ -140,24 +130,12 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id, user_id=user_id).first()
if file:
return FileModel.model_validate(file)
else:
return None
except Exception:
return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db: with get_db() as db:
try: try:
file = db.get(File, id) file = db.get(File, id)
return FileMetadataResponse( return FileMetadataResponse(
id=file.id, id=file.id,
hash=file.hash,
meta=file.meta, meta=file.meta,
created_at=file.created_at, created_at=file.created_at,
updated_at=file.updated_at, updated_at=file.updated_at,
@ -169,15 +147,6 @@ class FilesTable:
with get_db() as db: with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
file = self.get_file_by_id(id)
if not file:
return False
if file.user_id == user_id:
return True
# Implement additional access control logic here as needed
return False
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
with get_db() as db: with get_db() as db:
return [ return [
@ -193,14 +162,11 @@ class FilesTable:
return [ return [
FileMetadataResponse( FileMetadataResponse(
id=file.id, id=file.id,
hash=file.hash,
meta=file.meta, meta=file.meta,
created_at=file.created_at, created_at=file.created_at,
updated_at=file.updated_at, updated_at=file.updated_at,
) )
for file in db.query( for file in db.query(File)
File.id, File.hash, File.meta, File.created_at, File.updated_at
)
.filter(File.id.in_(ids)) .filter(File.id.in_(ids))
.order_by(File.updated_at.desc()) .order_by(File.updated_at.desc())
.all() .all()
@ -213,35 +179,11 @@ class FilesTable:
for file in db.query(File).filter_by(user_id=user_id).all() for file in db.query(File).filter_by(user_id=user_id).all()
] ]
def update_file_by_id(
self, id: str, form_data: FileUpdateForm
) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
if form_data.hash is not None:
file.hash = form_data.hash
if form_data.data is not None:
file.data = {**(file.data if file.data else {}), **form_data.data}
if form_data.meta is not None:
file.meta = {**(file.meta if file.meta else {}), **form_data.meta}
file.updated_at = int(time.time())
db.commit()
return FileModel.model_validate(file)
except Exception as e:
log.exception(f"Error updating file completely by id: {e}")
return None
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
with get_db() as db: with get_db() as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
file.hash = hash file.hash = hash
file.updated_at = int(time.time())
db.commit() db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
@ -253,7 +195,6 @@ class FilesTable:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
file.data = {**(file.data if file.data else {}), **data} file.data = {**(file.data if file.data else {}), **data}
file.updated_at = int(time.time())
db.commit() db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception as e: except Exception as e:
@ -265,7 +206,6 @@ class FilesTable:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
file.meta = {**(file.meta if file.meta else {}), **meta} file.meta = {**(file.meta if file.meta else {}), **meta}
file.updated_at = int(time.time())
db.commit() db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception: except Exception:
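For reference, a minimal sketch of the merge-style JSON column update used in the file data/meta hunks above; the helper name is illustrative, and it assumes a mapped row with a JSON attribute plus an updated_at column:

import time

def merge_json_column(row, attr: str, patch: dict) -> None:
    # Merge the patch over the existing JSON value (which may be NULL in the database),
    # reassigning the attribute so SQLAlchemy marks it dirty, then bump updated_at.
    current = getattr(row, attr) or {}
    setattr(row, attr, {**current, **patch})
    row.updated_at = int(time.time())

Reassigning the whole dict (rather than mutating it in place) is what makes a plain JSON column register the change on commit.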
View file
@ -2,16 +2,18 @@ import logging
import time import time
import uuid import uuid
from typing import Optional from typing import Optional
import re
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import get_permissions
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
@ -21,13 +23,12 @@ log = logging.getLogger(__name__)
class Folder(Base): class Folder(Base):
__tablename__ = "folder" __tablename__ = "folder"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
parent_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True)
user_id = Column(Text) user_id = Column(Text)
name = Column(Text) name = Column(Text)
items = Column(JSON, nullable=True) items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
data = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False) is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
@ -40,7 +41,6 @@ class FolderModel(BaseModel):
name: str name: str
items: Optional[dict] = None items: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
data: Optional[dict] = None
is_expanded: bool = False is_expanded: bool = False
created_at: int created_at: int
updated_at: int updated_at: int
@ -48,20 +48,6 @@ class FolderModel(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class FolderMetadataResponse(BaseModel):
icon: Optional[str] = None
class FolderNameIdResponse(BaseModel):
id: str
name: str
meta: Optional[FolderMetadataResponse] = None
parent_id: Optional[str] = None
is_expanded: bool = False
created_at: int
updated_at: int
#################### ####################
# Forms # Forms
#################### ####################
@ -69,21 +55,12 @@ class FolderNameIdResponse(BaseModel):
class FolderForm(BaseModel): class FolderForm(BaseModel):
name: str name: str
data: Optional[dict] = None
meta: Optional[dict] = None
model_config = ConfigDict(extra="allow")
class FolderUpdateForm(BaseModel):
name: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class FolderTable: class FolderTable:
def insert_new_folder( def insert_new_folder(
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None self, user_id: str, name: str, parent_id: Optional[str] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -91,7 +68,7 @@ class FolderTable:
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
**(form_data.model_dump(exclude_unset=True) or {}), "name": name,
"parent_id": parent_id, "parent_id": parent_id,
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
@ -126,7 +103,7 @@ class FolderTable:
def get_children_folders_by_id_and_user_id( def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[list[FolderModel]]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db() as db:
folders = [] folders = []
@ -210,8 +187,8 @@ class FolderTable:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def update_folder_by_id_and_user_id( def update_folder_name_by_id_and_user_id(
self, id: str, user_id: str, form_data: FolderUpdateForm self, id: str, user_id: str, name: str
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db() as db:
@ -220,35 +197,18 @@ class FolderTable:
if not folder: if not folder:
return None return None
form_data = form_data.model_dump(exclude_unset=True)
existing_folder = ( existing_folder = (
db.query(Folder) db.query(Folder)
.filter_by( .filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
name=form_data.get("name"),
parent_id=folder.parent_id,
user_id=user_id,
)
.first() .first()
) )
if existing_folder and existing_folder.id != id: if existing_folder:
return None return None
folder.name = form_data.get("name", folder.name) folder.name = name
if "data" in form_data:
folder.data = {
**(folder.data or {}),
**form_data["data"],
}
if "meta" in form_data:
folder.meta = {
**(folder.meta or {}),
**form_data["meta"],
}
folder.updated_at = int(time.time()) folder.updated_at = int(time.time())
db.commit() db.commit()
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
@ -276,15 +236,18 @@ class FolderTable:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: def delete_folder_by_id_and_user_id(
self, id: str, user_id: str, delete_chats=True
) -> bool:
try: try:
folder_ids = []
with get_db() as db: with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder: if not folder:
return folder_ids return False
folder_ids.append(folder.id) if delete_chats:
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
# Delete all children folders # Delete all children folders
def delete_children(folder): def delete_children(folder):
@ -292,9 +255,12 @@ class FolderTable:
folder.id, user_id folder.id, user_id
) )
for folder_child in folder_children: for folder_child in folder_children:
if delete_chats:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
delete_children(folder_child) delete_children(folder_child)
folder_ids.append(folder_child.id)
folder = db.query(Folder).filter_by(id=folder_child.id).first() folder = db.query(Folder).filter_by(id=folder_child.id).first()
db.delete(folder) db.delete(folder)
@ -303,62 +269,10 @@ class FolderTable:
delete_children(folder) delete_children(folder)
db.delete(folder) db.delete(folder)
db.commit() db.commit()
return folder_ids return True
except Exception as e: except Exception as e:
log.error(f"delete_folder: {e}") log.error(f"delete_folder: {e}")
return [] return False
def normalize_folder_name(self, name: str) -> str:
# Replace _ and space with a single space, lower case, collapse multiple spaces
name = re.sub(r"[\s_]+", " ", name)
return name.strip().lower()
def search_folders_by_names(
self, user_id: str, queries: list[str]
) -> list[FolderModel]:
"""
Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive.
"""
normalized_queries = [self.normalize_folder_name(q) for q in queries]
if not normalized_queries:
return []
results = {}
with get_db() as db:
folders = db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders:
if self.normalize_folder_name(folder.name) in normalized_queries:
results[folder.id] = FolderModel.model_validate(folder)
# get children folders
children = self.get_children_folders_by_id_and_user_id(
folder.id, user_id
)
for child in children:
results[child.id] = child
# Return the results as a list
if not results:
return []
else:
results = list(results.values())
return results
def search_folders_by_name_contains(
self, user_id: str, query: str
) -> list[FolderModel]:
"""
Partial match: normalized name contains (as substring) the normalized query.
"""
normalized_query = self.normalize_folder_name(query)
results = []
with get_db() as db:
folders = db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders:
norm_name = self.normalize_folder_name(folder.name)
if normalized_query in norm_name:
results.append(FolderModel.model_validate(folder))
return results
Folders = FolderTable() Folders = FolderTable()
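For reference, a runnable sketch of the folder-name normalization that backs the folder: search filter above; underscores and spaces are treated as equivalent and matching is case-insensitive:

import re

def normalize_folder_name(name: str) -> str:
    # Collapse runs of spaces/underscores into a single space, trim, lower-case.
    return re.sub(r"[\s_]+", " ", name).strip().lower()

def folder_name_matches(name: str, queries: list[str]) -> bool:
    # Exact match after normalization, as in search_folders_by_names.
    normalized = {normalize_folder_name(q) for q in queries}
    return normalize_folder_name(name) in normalized

# folder_name_matches("Work_Notes", ["work notes"])  -> True
# folder_name_matches("Work_Notes", ["notes"])       -> False (exact match, not substring)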
View file
@ -3,11 +3,13 @@ import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserModel from open_webui.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index from sqlalchemy import BigInteger, Boolean, Column, String, Text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Functions DB Schema # Functions DB Schema
@ -17,7 +19,7 @@ log = logging.getLogger(__name__)
class Function(Base): class Function(Base):
__tablename__ = "function" __tablename__ = "function"
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True)
user_id = Column(String) user_id = Column(String)
name = Column(Text) name = Column(Text)
type = Column(Text) type = Column(Text)
@ -29,13 +31,10 @@ class Function(Base):
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
created_at = Column(BigInteger) created_at = Column(BigInteger)
__table_args__ = (Index("is_global_idx", "is_global"),)
class FunctionMeta(BaseModel): class FunctionMeta(BaseModel):
description: Optional[str] = None description: Optional[str] = None
manifest: Optional[dict] = {} manifest: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
class FunctionModel(BaseModel): class FunctionModel(BaseModel):
@ -53,31 +52,11 @@ class FunctionModel(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class FunctionWithValvesModel(BaseModel):
id: str
user_id: str
name: str
type: str
content: str
meta: FunctionMeta
valves: Optional[dict] = None
is_active: bool = False
is_global: bool = False
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
#################### ####################
class FunctionUserResponse(FunctionModel):
user: Optional[UserModel] = None
class FunctionResponse(BaseModel): class FunctionResponse(BaseModel):
id: str id: str
user_id: str user_id: str
@ -130,8 +109,8 @@ class FunctionsTable:
return None return None
def sync_functions( def sync_functions(
self, user_id: str, functions: list[FunctionWithValvesModel] self, user_id: str, functions: list[FunctionModel]
) -> list[FunctionWithValvesModel]: ) -> list[FunctionModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try: try:
with get_db() as db: with get_db() as db:
@ -185,48 +164,19 @@ class FunctionsTable:
except Exception: except Exception:
return None return None
def get_functions( def get_functions(self, active_only=False) -> list[FunctionModel]:
self, active_only=False, include_valves=False
) -> list[FunctionModel | FunctionWithValvesModel]:
with get_db() as db: with get_db() as db:
if active_only: if active_only:
functions = db.query(Function).filter_by(is_active=True).all()
else:
functions = db.query(Function).all()
if include_valves:
return [ return [
FunctionWithValvesModel.model_validate(function) FunctionModel.model_validate(function)
for function in functions for function in db.query(Function).filter_by(is_active=True).all()
] ]
else: else:
return [ return [
FunctionModel.model_validate(function) for function in functions FunctionModel.model_validate(function)
for function in db.query(Function).all()
] ]
def get_function_list(self) -> list[FunctionUserResponse]:
with get_db() as db:
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
user_ids = list(set(func.user_id for func in functions))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
return [
FunctionUserResponse.model_validate(
{
**FunctionModel.model_validate(func).model_dump(),
"user": (
users_dict.get(func.user_id).model_dump()
if func.user_id in users_dict
else None
),
}
)
for func in functions
]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> list[FunctionModel]: ) -> list[FunctionModel]:
@ -285,29 +235,6 @@ class FunctionsTable:
except Exception: except Exception:
return None return None
def update_function_metadata_by_id(
self, id: str, metadata: dict
) -> Optional[FunctionModel]:
with get_db() as db:
try:
function = db.get(Function, id)
if function:
if function.meta:
function.meta = {**function.meta, **metadata}
else:
function.meta = metadata
function.updated_at = int(time.time())
db.commit()
db.refresh(function)
return self.get_function_by_id(id)
else:
return None
except Exception as e:
log.exception(f"Error updating function metadata by id {id}: {e}")
return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[dict]: ) -> Optional[dict]:
@ -323,7 +250,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"].get(id, {}) return user_settings["functions"]["valves"].get(id, {})
except Exception as e: except Exception as e:
log.exception(f"Error getting user values by id {id} and user id {user_id}") log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
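For reference, a minimal sketch of the per-user valves lookup shown in get_user_valves_by_id_and_user_id above, assuming user settings are stored as a nested dict; the helper name is illustrative:

def get_user_valves(user_settings: dict, function_id: str) -> dict:
    # Valves live under settings["functions"]["valves"][<function id>];
    # any missing level falls back to an empty dict.
    return (
        (user_settings or {})
        .get("functions", {})
        .get("valves", {})
        .get(function_id, {})
    )

# get_user_valves({"functions": {"valves": {"f1": {"temperature": 0.2}}}}, "f1")
# -> {'temperature': 0.2}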
View file
@ -5,26 +5,17 @@ from typing import Optional
import uuid import uuid
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import ( from sqlalchemy import BigInteger, Column, String, Text, JSON, func
BigInteger,
Column,
String,
Text,
JSON,
and_,
func,
ForeignKey,
cast,
or_,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# UserGroup DB Schema # UserGroup DB Schema
@ -44,12 +35,14 @@ class Group(Base):
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True) permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
class GroupModel(BaseModel): class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
user_id: str user_id: str
@ -60,62 +53,38 @@ class GroupModel(BaseModel):
meta: Optional[dict] = None meta: Optional[dict] = None
permissions: Optional[dict] = None permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class GroupMember(Base):
__tablename__ = "group_member"
id = Column(Text, unique=True, primary_key=True)
group_id = Column(
Text,
ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=True)
updated_at = Column(BigInteger, nullable=True)
class GroupMemberModel(BaseModel):
id: str
group_id: str
user_id: str
created_at: Optional[int] = None # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch
#################### ####################
# Forms # Forms
#################### ####################
class GroupResponse(GroupModel): class GroupResponse(BaseModel):
member_count: Optional[int] = None id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel): class GroupForm(BaseModel):
name: str name: str
description: str description: str
permissions: Optional[dict] = None permissions: Optional[dict] = None
data: Optional[dict] = None
class UserIdsForm(BaseModel):
user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm): class GroupUpdateForm(GroupForm):
pass user_ids: Optional[list[str]] = None
class GroupListResponse(BaseModel):
items: list[GroupResponse] = []
total: int = 0
class GroupTable: class GroupTable:
@ -146,94 +115,24 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_all_groups(self) -> list[GroupModel]: def get_groups(self) -> list[GroupModel]:
with get_db() as db: with get_db() as db:
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
return [GroupModel.model_validate(group) for group in groups]
def get_groups(self, filter) -> list[GroupResponse]:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
share_value = filter["share"]
json_share = Group.data["config"]["share"].as_boolean()
if share_value:
query = query.filter(
or_(
Group.data.is_(None),
json_share.is_(None),
json_share == True,
)
)
else:
query = query.filter(
and_(Group.data.isnot(None), json_share == False)
)
groups = query.order_by(Group.updated_at.desc()).all()
return [ return [
GroupResponse.model_validate( GroupModel.model_validate(group)
{ for group in db.query(Group).order_by(Group.updated_at.desc()).all()
**GroupModel.model_validate(group).model_dump(),
"member_count": self.get_group_member_count_by_id(group.id),
}
)
for group in groups
] ]
def search_groups(
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
) -> GroupListResponse:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
# 'share' is stored in data JSON, support both sqlite and postgres
share_value = filter["share"]
print("Filtering by share:", share_value)
query = query.filter(
Group.data.op("->>")("share") == str(share_value)
)
total = query.count()
query = query.order_by(Group.updated_at.desc())
groups = query.offset(skip).limit(limit).all()
return {
"items": [
GroupResponse.model_validate(
**GroupModel.model_validate(group).model_dump(),
member_count=self.get_group_member_count_by_id(group.id),
)
for group in groups
],
"total": total,
}
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db: with get_db() as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group) for group in db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id) .filter(
.filter(GroupMember.user_id == user_id) func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.order_by(Group.updated_at.desc()) .order_by(Group.updated_at.desc())
.all() .all()
] ]
@ -246,63 +145,12 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]: def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
with get_db() as db: group = self.get_group_by_id(id)
members = ( if group:
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() return group.user_ids
) else:
return None
if not members:
return None
return [m[0] for m in members]
def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]:
with get_db() as db:
members = (
db.query(GroupMember.group_id, GroupMember.user_id)
.filter(GroupMember.group_id.in_(group_ids))
.all()
)
group_user_ids: dict[str, list[str]] = {
group_id: [] for group_id in group_ids
}
for group_id, user_id in members:
group_user_ids[group_id].append(user_id)
return group_user_ids
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
with get_db() as db:
# Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
# Insert new members
now = int(time.time())
new_members = [
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
)
for user_id in user_ids
]
db.add_all(new_members)
db.commit()
def get_group_member_count_by_id(self, id: str) -> int:
with get_db() as db:
count = (
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
return count if count else 0
def update_group_by_id( def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
@ -343,189 +191,59 @@ class GroupTable:
def remove_user_from_all_groups(self, user_id: str) -> bool: def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
# Find all groups the user belongs to groups = self.get_groups_by_member_id(user_id)
groups = (
db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.all()
)
# Remove the user from each group
for group in groups: for group in groups:
db.query(GroupMember).filter( group.user_ids.remove(user_id)
GroupMember.group_id == group.id, GroupMember.user_id == user_id
).delete()
db.query(Group).filter_by(id=group.id).update( db.query(Group).filter_by(id=group.id).update(
{"updated_at": int(time.time())} {
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
) )
db.commit()
db.commit()
return True return True
except Exception: except Exception:
db.rollback()
return False return False
def create_groups_by_group_names( def sync_user_groups_by_group_names(
self, user_id: str, group_names: list[str] self, user_id: str, group_names: list[str]
) -> list[GroupModel]: ) -> bool:
# check for existing groups
existing_groups = self.get_all_groups()
existing_group_names = {group.name for group in existing_groups}
new_groups = []
with get_db() as db:
for group_name in group_names:
if group_name not in existing_group_names:
new_group = GroupModel(
id=str(uuid.uuid4()),
user_id=user_id,
name=group_name,
description="",
created_at=int(time.time()),
updated_at=int(time.time()),
)
try:
result = Group(**new_group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
new_groups.append(GroupModel.model_validate(result))
except Exception as e:
log.exception(e)
continue
return new_groups
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
with get_db() as db: with get_db() as db:
try: try:
now = int(time.time()) groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
# 1. Groups that SHOULD contain the user # Remove user from groups not in the new list
target_groups = ( existing_groups = self.get_groups_by_member_id(user_id)
db.query(Group).filter(Group.name.in_(group_names)).all()
)
target_group_ids = {g.id for g in target_groups}
# 2. Groups the user is CURRENTLY in for group in existing_groups:
existing_group_ids = { if group.id not in group_ids:
g.id group.user_ids.remove(user_id)
for g in db.query(Group) db.query(Group).filter_by(id=group.id).update(
.join(GroupMember, GroupMember.group_id == Group.id) {
.filter(GroupMember.user_id == user_id) "user_ids": group.user_ids,
.all() "updated_at": int(time.time()),
} }
# 3. Determine adds + removals
groups_to_add = target_group_ids - existing_group_ids
groups_to_remove = existing_group_ids - target_group_ids
# 4. Remove in one bulk delete
if groups_to_remove:
db.query(GroupMember).filter(
GroupMember.user_id == user_id,
GroupMember.group_id.in_(groups_to_remove),
).delete(synchronize_session=False)
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
{"updated_at": now}, synchronize_session=False
)
# 5. Bulk insert missing memberships
for group_id in groups_to_add:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
) )
)
if groups_to_add: # Add user to new groups
db.query(Group).filter(Group.id.in_(groups_to_add)).update( for group in groups:
{"updated_at": now}, synchronize_session=False if user_id not in group.user_ids:
) group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
db.rollback()
return False return False
def add_users_to_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
if not group:
return None
now = int(time.time())
for user_id in user_ids or []:
try:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=id,
user_id=user_id,
created_at=now,
updated_at=now,
)
)
db.flush() # Detect unique constraint violation early
except Exception:
db.rollback() # Clear failed INSERT
db.begin() # Start a new transaction
continue # Duplicate → ignore
group.updated_at = now
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
def remove_users_from_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
if not group:
return None
if not user_ids:
return GroupModel.model_validate(group)
# Remove each user from group_member
for user_id in user_ids:
db.query(GroupMember).filter(
GroupMember.group_id == id, GroupMember.user_id == user_id
).delete()
# Update group timestamp
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
Groups = GroupTable() Groups = GroupTable()
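For reference, a minimal sketch of the reconciliation step inside sync_groups_by_group_names above, reduced to the set arithmetic on group IDs; the function name is illustrative:

def plan_membership_sync(target_group_ids: set[str], existing_group_ids: set[str]):
    # Groups the user must be added to, and groups they must be removed from.
    to_add = target_group_ids - existing_group_ids
    to_remove = existing_group_ids - target_group_ids
    return to_add, to_remove

# plan_membership_sync({"a", "b"}, {"b", "c"}) -> ({"a"}, {"c"})

The diff then applies to_remove with one bulk DELETE ... WHERE group_id IN (...) and inserts one group_member row per entry in to_add.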
View file
@ -5,34 +5,19 @@ from typing import Optional
import uuid import uuid
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import ( from open_webui.models.files import FileMetadataResponse
File, from open_webui.models.users import Users, UserResponse
FileModel,
FileMetadataResponse,
FileModelResponse,
)
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import ( from sqlalchemy import BigInteger, Column, String, Text, JSON
BigInteger,
Column,
ForeignKey,
String,
Text,
JSON,
UniqueConstraint,
or_,
)
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.utils.db.access_control import has_permission
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Knowledge DB Schema # Knowledge DB Schema
@ -48,7 +33,9 @@ class Knowledge(Base):
name = Column(Text) name = Column(Text)
description = Column(Text) description = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True) # Controls data access levels. access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry. # Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role. # - `None`: Public access, available to all users with the "user" role.
@ -79,6 +66,7 @@ class KnowledgeModel(BaseModel):
name: str name: str
description: str description: str
data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
access_control: Optional[dict] = None access_control: Optional[dict] = None
@ -87,42 +75,11 @@ class KnowledgeModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
class KnowledgeFile(Base):
__tablename__ = "knowledge_file"
id = Column(Text, unique=True, primary_key=True)
knowledge_id = Column(
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
),
)
class KnowledgeFileModel(BaseModel):
id: str
knowledge_id: str
file_id: str
user_id: str
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
#################### ####################
class KnowledgeUserModel(KnowledgeModel): class KnowledgeUserModel(KnowledgeModel):
user: Optional[UserResponse] = None user: Optional[UserResponse] = None
@ -132,29 +89,16 @@ class KnowledgeResponse(KnowledgeModel):
class KnowledgeUserResponse(KnowledgeUserModel): class KnowledgeUserResponse(KnowledgeUserModel):
pass files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeForm(BaseModel): class KnowledgeForm(BaseModel):
name: str name: str
description: str description: str
data: Optional[dict] = None
access_control: Optional[dict] = None access_control: Optional[dict] = None
class FileUserResponse(FileModelResponse):
user: Optional[UserResponse] = None
class KnowledgeListResponse(BaseModel):
items: list[KnowledgeUserModel]
total: int
class KnowledgeFileListResponse(BaseModel):
items: list[FileUserResponse]
total: int
class KnowledgeTable: class KnowledgeTable:
def insert_new_knowledge( def insert_new_knowledge(
self, user_id: str, form_data: KnowledgeForm self, user_id: str, form_data: KnowledgeForm
@ -182,21 +126,13 @@ class KnowledgeTable:
except Exception: except Exception:
return None return None
def get_knowledge_bases( def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
self, skip: int = 0, limit: int = 30
) -> list[KnowledgeUserModel]:
with get_db() as db: with get_db() as db:
all_knowledge = (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
)
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
knowledge_bases = [] knowledge_bases = []
for knowledge in all_knowledge: for knowledge in (
user = users_dict.get(knowledge.user_id) db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
):
user = Users.get_user_by_id(knowledge.user_id)
knowledge_bases.append( knowledge_bases.append(
KnowledgeUserModel.model_validate( KnowledgeUserModel.model_validate(
{ {
@ -207,147 +143,15 @@ class KnowledgeTable:
) )
return knowledge_bases return knowledge_bases
def search_knowledge_bases(
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeListResponse:
try:
with get_db() as db:
query = db.query(Knowledge, User).outerjoin(
User, User.id == Knowledge.user_id
)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Knowledge.name.ilike(f"%{query_key}%"),
Knowledge.description.ilike(f"%{query_key}%"),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Knowledge.user_id == user_id)
elif view_option == "shared":
query = query.filter(Knowledge.user_id != user_id)
query = has_permission(db, Knowledge, query, filter)
query = query.order_by(Knowledge.updated_at.desc())
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
knowledge_bases = []
for knowledge_base, user in items:
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
**KnowledgeModel.model_validate(
knowledge_base
).model_dump(),
"user": (
UserModel.model_validate(user).model_dump()
if user
else None
),
}
)
)
return KnowledgeListResponse(items=knowledge_bases, total=total)
except Exception as e:
print(e)
return KnowledgeListResponse(items=[], total=0)
def search_knowledge_files(
self, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeFileListResponse:
"""
Scalable version: search files across all knowledge bases the user has
READ access to, without loading all KBs or using large IN() lists.
"""
try:
with get_db() as db:
# Base query: join Knowledge → KnowledgeFile → File
query = (
db.query(File, User)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.join(Knowledge, KnowledgeFile.knowledge_id == Knowledge.id)
.outerjoin(User, User.id == KnowledgeFile.user_id)
)
# Apply access-control directly to the joined query
# This makes the database handle filtering, even with 10k+ KBs
query = has_permission(db, Knowledge, query, filter)
# Apply filename search
if filter:
q = filter.get("query")
if q:
query = query.filter(File.filename.ilike(f"%{q}%"))
# Order by file changes
query = query.order_by(File.updated_at.desc())
# Count before pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
rows = query.all()
items = []
for file, user in rows:
items.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
)
)
return KnowledgeFileListResponse(items=items, total=total)
except Exception as e:
print("search_knowledge_files error:", e)
return KnowledgeFileListResponse(items=[], total=0)
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
knowledge = self.get_knowledge_by_id(id)
if not knowledge:
return False
if knowledge.user_id == user_id:
return True
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return has_access(user_id, permission, knowledge.access_control, user_group_ids)
def get_knowledge_bases_by_user_id( def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases() knowledge_bases = self.get_knowledge_bases()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
knowledge_base knowledge_base
for knowledge_base in knowledge_bases for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id if knowledge_base.user_id == user_id
or has_access( or has_access(user_id, permission, knowledge_base.access_control)
user_id, permission, knowledge_base.access_control, user_group_ids
)
] ]
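A pattern that recurs throughout this diff: the caller resolves the user's group memberships once via `Groups.get_groups_by_member_id` and passes the resulting set into `has_access`, instead of letting every access check query groups again. A minimal sketch of the idea, with the `has_access` signature taken from the call sites above:

```python
# Minimal sketch: resolve group memberships once, then reuse the set for every
# per-item access check, avoiding an N+1 query pattern inside has_access().
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access

def filter_readable(user_id: str, items: list) -> list:
    user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user_id)}  # one query
    return [
        item
        for item in items
        if item.user_id == user_id
        or has_access(user_id, "read", item.access_control, user_group_ids)
    ]
```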
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
@@ -358,197 +162,6 @@ class KnowledgeTable:
except Exception: except Exception:
return None return None
def get_knowledge_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[KnowledgeModel]:
knowledge = self.get_knowledge_by_id(id)
if not knowledge:
return None
if knowledge.user_id == user_id:
return knowledge
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
if has_access(user_id, "write", knowledge.access_control, user_group_ids):
return knowledge
return None
def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]:
try:
with get_db() as db:
knowledges = (
db.query(Knowledge)
.join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
.filter(KnowledgeFile.file_id == file_id)
.all()
)
return [
KnowledgeModel.model_validate(knowledge) for knowledge in knowledges
]
except Exception:
return []
def search_files_by_id(
self,
knowledge_id: str,
user_id: str,
filter: dict,
skip: int = 0,
limit: int = 30,
) -> KnowledgeFileListResponse:
try:
with get_db() as db:
query = (
db.query(File, User)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.outerjoin(User, User.id == KnowledgeFile.user_id)
.filter(KnowledgeFile.knowledge_id == knowledge_id)
)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(or_(File.filename.ilike(f"%{query_key}%")))
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(KnowledgeFile.user_id == user_id)
elif view_option == "shared":
query = query.filter(KnowledgeFile.user_id != user_id)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(File.filename.asc())
else:
query = query.order_by(File.filename.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(File.created_at.asc())
else:
query = query.order_by(File.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(File.updated_at.asc())
else:
query = query.order_by(File.updated_at.desc())
else:
query = query.order_by(File.updated_at.desc())
else:
query = query.order_by(File.updated_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
files = []
for file, user in items:
files.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
)
)
return KnowledgeFileListResponse(items=files, total=total)
except Exception as e:
print(e)
return KnowledgeFileListResponse(items=[], total=0)
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
try:
with get_db() as db:
files = (
db.query(File)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.filter(KnowledgeFile.knowledge_id == knowledge_id)
.all()
)
return [FileModel.model_validate(file) for file in files]
except Exception:
return []
def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]:
try:
with get_db() as db:
files = self.get_files_by_id(knowledge_id)
return [FileMetadataResponse(**file.model_dump()) for file in files]
except Exception:
return []
def add_file_to_knowledge_by_id(
self, knowledge_id: str, file_id: str, user_id: str
) -> Optional[KnowledgeFileModel]:
with get_db() as db:
knowledge_file = KnowledgeFileModel(
**{
"id": str(uuid.uuid4()),
"knowledge_id": knowledge_id,
"file_id": file_id,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = KnowledgeFile(**knowledge_file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return KnowledgeFileModel.model_validate(result)
else:
return None
except Exception:
return None
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool:
try:
with get_db() as db:
db.query(KnowledgeFile).filter_by(
knowledge_id=knowledge_id, file_id=file_id
).delete()
db.commit()
return True
except Exception:
return False
def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
# Delete all knowledge_file entries for this knowledge_id
db.query(KnowledgeFile).filter_by(knowledge_id=id).delete()
db.commit()
# Update the knowledge entry's updated_at timestamp
db.query(Knowledge).filter_by(id=id).update(
{
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def update_knowledge_by_id( def update_knowledge_by_id(
self, id: str, form_data: KnowledgeForm, overwrite: bool = False self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:

View file

@@ -14,7 +14,7 @@ from sqlalchemy import BigInteger, Column, String, Text
class Memory(Base): class Memory(Base):
__tablename__ = "memory" __tablename__ = "memory"
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True)
user_id = Column(String) user_id = Column(String)
content = Column(Text) content = Column(Text)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
@@ -71,13 +71,9 @@ class MemoriesTable:
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
memory = db.get(Memory, id) db.query(Memory).filter_by(id=id, user_id=user_id).update(
if not memory or memory.user_id != user_id: {"content": content, "updated_at": int(time.time())}
return None )
memory.content = content
memory.updated_at = int(time.time())
db.commit() db.commit()
return self.get_memory_by_id(id) return self.get_memory_by_id(id)
except Exception: except Exception:
@@ -131,12 +127,7 @@ class MemoriesTable:
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
memory = db.get(Memory, id) db.query(Memory).filter_by(id=id, user_id=user_id).delete()
if not memory or memory.user_id != user_id:
return None
# Delete the memory
db.delete(memory)
db.commit() db.commit()
return True return True

View file

@@ -5,11 +5,9 @@ from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.users import Users, User, UserNameResponse
from open_webui.models.channels import Channels, ChannelMember
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
@@ -21,7 +19,7 @@ from sqlalchemy.sql import exists
class MessageReaction(Base): class MessageReaction(Base):
__tablename__ = "message_reaction" __tablename__ = "message_reaction"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
message_id = Column(Text) message_id = Column(Text)
name = Column(Text) name = Column(Text)
@@ -40,19 +38,13 @@ class MessageReactionModel(BaseModel):
class Message(Base): class Message(Base):
__tablename__ = "message" __tablename__ = "message"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
channel_id = Column(Text, nullable=True) channel_id = Column(Text, nullable=True)
reply_to_id = Column(Text, nullable=True)
parent_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True)
# Pins
is_pinned = Column(Boolean, nullable=False, default=False)
pinned_at = Column(BigInteger, nullable=True)
pinned_by = Column(Text, nullable=True)
content = Column(Text) content = Column(Text)
data = Column(JSON, nullable=True) data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
@@ -68,20 +60,14 @@ class MessageModel(BaseModel):
user_id: str user_id: str
channel_id: Optional[str] = None channel_id: Optional[str] = None
reply_to_id: Optional[str] = None
parent_id: Optional[str] = None parent_id: Optional[str] = None
# Pins
is_pinned: bool = False
pinned_by: Optional[str] = None
pinned_at: Optional[int] = None # timestamp in epoch (time_ns)
content: str content: str
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
created_at: int # timestamp in epoch (time_ns) created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch (time_ns) updated_at: int # timestamp in epoch
#################### ####################
@@ -90,9 +76,7 @@ class MessageModel(BaseModel):
class MessageForm(BaseModel): class MessageForm(BaseModel):
temp_id: Optional[str] = None
content: str content: str
reply_to_id: Optional[str] = None
parent_id: Optional[str] = None parent_id: Optional[str] = None
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
@@ -100,36 +84,11 @@ class MessageForm(BaseModel):
class Reactions(BaseModel): class Reactions(BaseModel):
name: str name: str
users: list[dict] user_ids: list[str]
count: int count: int
class MessageUserResponse(MessageModel): class MessageResponse(MessageModel):
user: Optional[UserNameResponse] = None
class MessageUserSlimResponse(MessageUserResponse):
data: bool | None = None
@field_validator("data", mode="before")
def convert_data_to_bool(cls, v):
# No data or not a dict → False
if not isinstance(v, dict):
return False
# True if ANY value in the dict is non-empty
return any(bool(val) for val in v.values())
class MessageReplyToResponse(MessageUserResponse):
reply_to_message: Optional[MessageUserSlimResponse] = None
class MessageWithReactionsResponse(MessageUserSlimResponse):
reactions: list[Reactions]
class MessageResponse(MessageReplyToResponse):
latest_reply_at: Optional[int] latest_reply_at: Optional[int]
reply_count: int reply_count: int
reactions: list[Reactions] reactions: list[Reactions]
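The new `MessageUserSlimResponse` drops the full `data` payload and exposes only a boolean flag via a Pydantic `field_validator`. A standalone re-creation of just that validator, to show the conversion (illustrative model, not the project's class):

```python
# Illustration only: mirrors the "data" validator above on a throwaway model.
from typing import Optional
from pydantic import BaseModel, field_validator

class SlimData(BaseModel):
    data: Optional[bool] = None

    @field_validator("data", mode="before")
    def convert_data_to_bool(cls, v):
        if not isinstance(v, dict):  # no data or not a dict -> False
            return False
        return any(bool(val) for val in v.values())  # True if ANY value is non-empty

print(SlimData(data=None).data)                       # False
print(SlimData(data={"files": []}).data)              # False (all values empty)
print(SlimData(data={"files": [{"id": "x"}]}).data)   # True
```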
@@ -140,21 +99,15 @@ class MessageTable:
self, form_data: MessageForm, channel_id: str, user_id: str self, form_data: MessageForm, channel_id: str, user_id: str
) -> Optional[MessageModel]: ) -> Optional[MessageModel]:
with get_db() as db: with get_db() as db:
channel_member = Channels.join_channel(channel_id, user_id)
id = str(uuid.uuid4()) id = str(uuid.uuid4())
ts = int(time.time_ns())
ts = int(time.time_ns())
message = MessageModel( message = MessageModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"channel_id": channel_id, "channel_id": channel_id,
"reply_to_id": form_data.reply_to_id,
"parent_id": form_data.parent_id, "parent_id": form_data.parent_id,
"is_pinned": False,
"pinned_at": None,
"pinned_by": None,
"content": form_data.content, "content": form_data.content,
"data": form_data.data, "data": form_data.data,
"meta": form_data.meta, "meta": form_data.meta,
@@ -162,8 +115,8 @@ class MessageTable:
"updated_at": ts, "updated_at": ts,
} }
) )
result = Message(**message.model_dump())
result = Message(**message.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@@ -175,32 +128,19 @@ class MessageTable:
if not message: if not message:
return None return None
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
reactions = self.get_reactions_by_message_id(id) reactions = self.get_reactions_by_message_id(id)
thread_replies = self.get_thread_replies_by_message_id(id) replies = self.get_replies_by_message_id(id)
user = Users.get_user_by_id(message.user_id) return MessageResponse(
return MessageResponse.model_validate( **{
{
**MessageModel.model_validate(message).model_dump(), **MessageModel.model_validate(message).model_dump(),
"user": user.model_dump() if user else None, "latest_reply_at": replies[0].created_at if replies else None,
"reply_to_message": ( "reply_count": len(replies),
reply_to_message.model_dump() if reply_to_message else None
),
"latest_reply_at": (
thread_replies[0].created_at if thread_replies else None
),
"reply_count": len(thread_replies),
"reactions": reactions, "reactions": reactions,
} }
) )
def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
with get_db() as db: with get_db() as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
@@ -208,27 +148,7 @@ class MessageTable:
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.all() .all()
) )
return [MessageModel.model_validate(message) for message in all_messages]
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
}
)
)
return messages
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
with get_db() as db: with get_db() as db:
@@ -239,7 +159,7 @@ class MessageTable:
def get_messages_by_channel_id( def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageReplyToResponse]: ) -> list[MessageModel]:
with get_db() as db: with get_db() as db:
all_messages = ( all_messages = (
db.query(Message) db.query(Message)
@@ -249,31 +169,11 @@ class MessageTable:
.limit(limit) .limit(limit)
.all() .all()
) )
return [MessageModel.model_validate(message) for message in all_messages]
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
}
)
)
return messages
def get_messages_by_parent_id( def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
) -> list[MessageReplyToResponse]: ) -> list[MessageModel]:
with get_db() as db: with get_db() as db:
message = db.get(Message, parent_id) message = db.get(Message, parent_id)
@@ -293,49 +193,6 @@ class MessageTable:
if len(all_messages) < limit: if len(all_messages) < limit:
all_messages.append(message) all_messages.append(message)
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
}
)
)
return messages
def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]:
with get_db() as db:
message = (
db.query(Message)
.filter_by(channel_id=channel_id)
.order_by(Message.created_at.desc())
.first()
)
return MessageModel.model_validate(message) if message else None
def get_pinned_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id, is_pinned=True)
.order_by(Message.pinned_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages] return [MessageModel.model_validate(message) for message in all_messages]
def update_message_by_id( def update_message_by_id(
@@ -344,57 +201,17 @@ class MessageTable:
with get_db() as db: with get_db() as db:
message = db.get(Message, id) message = db.get(Message, id)
message.content = form_data.content message.content = form_data.content
message.data = { message.data = form_data.data
**(message.data if message.data else {}), message.meta = form_data.meta
**(form_data.data if form_data.data else {}),
}
message.meta = {
**(message.meta if message.meta else {}),
**(form_data.meta if form_data.meta else {}),
}
message.updated_at = int(time.time_ns()) message.updated_at = int(time.time_ns())
db.commit() db.commit()
db.refresh(message) db.refresh(message)
return MessageModel.model_validate(message) if message else None return MessageModel.model_validate(message) if message else None
def update_is_pinned_by_id(
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None
) -> Optional[MessageModel]:
with get_db() as db:
message = db.get(Message, id)
message.is_pinned = is_pinned
message.pinned_at = int(time.time_ns()) if is_pinned else None
message.pinned_by = pinned_by if is_pinned else None
db.commit()
db.refresh(message)
return MessageModel.model_validate(message) if message else None
def get_unread_message_count(
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None
) -> int:
with get_db() as db:
query = db.query(Message).filter(
Message.channel_id == channel_id,
Message.parent_id == None, # only count top-level messages
Message.created_at > (last_read_at if last_read_at else 0),
)
if user_id:
query = query.filter(Message.user_id != user_id)
return query.count()
def add_reaction_to_message( def add_reaction_to_message(
self, id: str, user_id: str, name: str self, id: str, user_id: str, name: str
) -> Optional[MessageReactionModel]: ) -> Optional[MessageReactionModel]:
with get_db() as db: with get_db() as db:
# check for existing reaction
existing_reaction = (
db.query(MessageReaction)
.filter_by(message_id=id, user_id=user_id, name=name)
.first()
)
if existing_reaction:
return MessageReactionModel.model_validate(existing_reaction)
reaction_id = str(uuid.uuid4()) reaction_id = str(uuid.uuid4())
reaction = MessageReactionModel( reaction = MessageReactionModel(
id=reaction_id, id=reaction_id,
@@ -411,30 +228,17 @@ class MessageTable:
def get_reactions_by_message_id(self, id: str) -> list[Reactions]: def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
with get_db() as db: with get_db() as db:
# JOIN User so all user info is fetched in one query all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
results = (
db.query(MessageReaction, User)
.join(User, MessageReaction.user_id == User.id)
.filter(MessageReaction.message_id == id)
.all()
)
reactions = {} reactions = {}
for reaction in all_reactions:
for reaction, user in results:
if reaction.name not in reactions: if reaction.name not in reactions:
reactions[reaction.name] = { reactions[reaction.name] = {
"name": reaction.name, "name": reaction.name,
"users": [], "user_ids": [],
"count": 0, "count": 0,
} }
reactions[reaction.name]["user_ids"].append(reaction.user_id)
reactions[reaction.name]["users"].append(
{
"id": user.id,
"name": user.name,
}
)
reactions[reaction.name]["count"] += 1 reactions[reaction.name]["count"] += 1
return [Reactions(**reaction) for reaction in reactions.values()] return [Reactions(**reaction) for reaction in reactions.values()]
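With the `JOIN` on `User`, the reaction summary now carries reactor ids and names in a single query instead of bare `user_ids`. A plain-Python sketch of the same grouping, on made-up rows, showing the new `Reactions` payload shape:

```python
# Plain-Python sketch of the grouping above, on made-up (reaction_name, user) rows.
# The output mirrors the new Reactions shape: users as {id, name} dicts, not ids.
rows = [
    ("thumbsup", {"id": "u1", "name": "Ada"}),
    ("thumbsup", {"id": "u2", "name": "Grace"}),
    ("eyes", {"id": "u1", "name": "Ada"}),
]

reactions: dict[str, dict] = {}
for name, user in rows:
    entry = reactions.setdefault(name, {"name": name, "users": [], "count": 0})
    entry["users"].append(user)
    entry["count"] += 1

print(list(reactions.values()))
# [{'name': 'thumbsup', 'users': [{'id': 'u1', 'name': 'Ada'}, {'id': 'u2', 'name': 'Grace'}], 'count': 2},
#  {'name': 'eyes', 'users': [{'id': 'u1', 'name': 'Ada'}], 'count': 1}]
```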

View file

@@ -3,17 +3,15 @@ import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse
from open_webui.models.users import User, UserModel, Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, cast, or_, and_, func from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
@@ -21,6 +19,7 @@ from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
@@ -53,7 +52,7 @@ class ModelMeta(BaseModel):
class Model(Base): class Model(Base):
__tablename__ = "model" __tablename__ = "model"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
@@ -133,11 +132,6 @@ class ModelResponse(ModelModel):
pass pass
class ModelListResponse(BaseModel):
items: list[ModelUserResponse]
total: int
class ModelForm(BaseModel): class ModelForm(BaseModel):
id: str id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
@@ -181,16 +175,9 @@ class ModelsTable:
def get_models(self) -> list[ModelUserResponse]: def get_models(self) -> list[ModelUserResponse]:
with get_db() as db: with get_db() as db:
all_models = db.query(Model).filter(Model.base_model_id != None).all()
user_ids = list(set(model.user_id for model in all_models))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
models = [] models = []
for model in all_models: for model in db.query(Model).filter(Model.base_model_id != None).all():
user = users_dict.get(model.user_id) user = Users.get_user_by_id(model.user_id)
models.append( models.append(
ModelUserResponse.model_validate( ModelUserResponse.model_validate(
{ {
@@ -212,143 +199,13 @@ class ModelsTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]: ) -> list[ModelUserResponse]:
models = self.get_models() models = self.get_models()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
model model
for model in models for model in models
if model.user_id == user_id if model.user_id == user_id
or has_access(user_id, permission, model.access_control, user_group_ids) or has_access(user_id, permission, model.access_control)
] ]
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
dialect_name = db.bind.dialect.name
# Public access
conditions = []
if group_ids or user_id:
conditions.extend(
[
Model.access_control.is_(None),
cast(Model.access_control, String) == "null",
]
)
# User-level permission
if user_id:
conditions.append(Model.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
Model.access_control[permission]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Model.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query
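`_has_permission` filters on the `access_control` JSON column directly in SQL. The shape used below is an assumption, inferred from the `access_control[permission]["group_ids"]` path indexed above and from how `has_access` is called elsewhere in this diff; the pure-Python equivalent shows when a row passes the filter (made-up data):

```python
# Assumed shape of Model.access_control; "user_ids" is checked by has_access()
# elsewhere, while the SQL filter above only tests group containment.
access_control = {
    "read":  {"group_ids": ["g-research", "g-support"], "user_ids": []},
    "write": {"group_ids": ["g-research"], "user_ids": ["u-42"]},
}

def visible(row_user_id, row_access_control, user_id, group_ids, permission="write"):
    if row_access_control is None:  # public entry (column stored as NULL / "null")
        return True
    if row_user_id == user_id:      # owner always passes
        return True
    allowed = row_access_control.get(permission, {}).get("group_ids", [])
    return any(gid in allowed for gid in group_ids)  # SQL-side JSON "contains"

print(visible("u-1", access_control, "u-7", ["g-support"]))   # False for "write"
print(visible("u-1", access_control, "u-7", ["g-research"]))  # True
```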
def search_models(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> ModelListResponse:
with get_db() as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
query = query.filter(Model.base_model_id != None)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Model.name.ilike(f"%{query_key}%"),
Model.base_model_id.ilike(f"%{query_key}%"),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Model.user_id == user_id)
elif view_option == "shared":
query = query.filter(Model.user_id != user_id)
# Apply access control filtering
query = self._has_permission(
db,
query,
filter,
permission="write",
)
tag = filter.get("tag")
if tag:
# TODO: This is a simple implementation and should be improved for performance
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
meta_text = func.lower(cast(Model.meta, String))
query = query.filter(meta_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Model.name.asc())
else:
query = query.order_by(Model.name.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Model.created_at.asc())
else:
query = query.order_by(Model.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Model.updated_at.asc())
else:
query = query.order_by(Model.updated_at.desc())
else:
query = query.order_by(Model.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
models = []
for model, user in items:
models.append(
ModelUserResponse(
**ModelModel.model_validate(model).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
)
)
return ModelListResponse(items=models, total=total)
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db() as db:
@@ -357,14 +214,6 @@ class ModelsTable:
except Exception: except Exception:
return None return None
def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]:
try:
with get_db() as db:
models = db.query(Model).filter(Model.id.in_(ids)).all()
return [ModelModel.model_validate(model) for model in models]
except Exception:
return []
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
with get_db() as db: with get_db() as db:
try: try:
@@ -386,9 +235,11 @@ class ModelsTable:
try: try:
with get_db() as db: with get_db() as db:
# update only the fields that are present in the model # update only the fields that are present in the model
data = model.model_dump(exclude={"id"}) result = (
result = db.query(Model).filter_by(id=id).update(data) db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}))
)
db.commit() db.commit()
model = db.get(Model, id) model = db.get(Model, id)
@@ -418,49 +269,5 @@ class ModelsTable:
except Exception: except Exception:
return False return False
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
try:
with get_db() as db:
# Get existing models
existing_models = db.query(Model).all()
existing_ids = {model.id for model in existing_models}
# Prepare a set of new model IDs
new_model_ids = {model.id for model in models}
# Update or insert models
for model in models:
if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update(
{
**model.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
else:
new_model = Model(
**{
**model.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
db.add(new_model)
# Remove models that are no longer present
for model in existing_models:
if model.id not in new_model_ids:
db.delete(model)
db.commit()
return [
ModelModel.model_validate(model) for model in db.query(Model).all()
]
except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}")
return []
Models = ModelsTable() Models = ModelsTable()

View file

@@ -2,20 +2,15 @@ import json
import time import time
import uuid import uuid
from typing import Optional from typing import Optional
from functools import lru_cache
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.models.users import User, UserModel, Users, UserResponse from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy import or_, func, select, and_, text
from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
#################### ####################
@@ -26,7 +21,7 @@ from sqlalchemy.sql import exists
class Note(Base): class Note(Base):
__tablename__ = "note" __tablename__ = "note"
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
title = Column(Text) title = Column(Text)
@@ -67,149 +62,11 @@ class NoteForm(BaseModel):
access_control: Optional[dict] = None access_control: Optional[dict] = None
class NoteUpdateForm(BaseModel):
title: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class NoteUserResponse(NoteModel): class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None user: Optional[UserResponse] = None
class NoteItemResponse(BaseModel):
id: str
title: str
data: Optional[dict]
updated_at: int
created_at: int
user: Optional[UserResponse] = None
class NoteListResponse(BaseModel):
items: list[NoteUserResponse]
total: int
class NoteTable: class NoteTable:
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
dialect_name = db.bind.dialect.name
conditions = []
# Handle read_only permission separately
if permission == "read_only":
# For read_only, we want items where:
# 1. User has explicit read permission (via groups or user-level)
# 2. BUT does NOT have write permission
# 3. Public items are NOT considered read_only
read_conditions = []
# Group-level read permission
if group_ids:
group_read_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_read_conditions.append(
Note.access_control["read"]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_read_conditions.append(
cast(
Note.access_control["read"]["group_ids"],
JSONB,
).contains([gid])
)
if group_read_conditions:
read_conditions.append(or_(*group_read_conditions))
# Combine read conditions
if read_conditions:
has_read = or_(*read_conditions)
else:
# If no read conditions, return empty result
return query.filter(False)
# Now exclude items where user has write permission
write_exclusions = []
# Exclude items owned by user (they have implicit write)
if user_id:
write_exclusions.append(Note.user_id != user_id)
# Exclude items where user has explicit write permission via groups
if group_ids:
group_write_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_write_conditions.append(
Note.access_control["write"]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_write_conditions.append(
cast(
Note.access_control["write"]["group_ids"],
JSONB,
).contains([gid])
)
if group_write_conditions:
# User should NOT have write permission
write_exclusions.append(~or_(*group_write_conditions))
# Exclude public items (items without access_control)
write_exclusions.append(Note.access_control.isnot(None))
write_exclusions.append(cast(Note.access_control, String) != "null")
# Combine: has read AND does not have write AND not public
if write_exclusions:
query = query.filter(and_(has_read, *write_exclusions))
else:
query = query.filter(has_read)
return query
# Original logic for other permissions (read, write, etc.)
# Public access conditions
if group_ids or user_id:
conditions.extend(
[
Note.access_control.is_(None),
cast(Note.access_control, String) == "null",
]
)
# User-level permission (owner has all permissions)
if user_id:
conditions.append(Note.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
Note.access_control[permission]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Note.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query
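A worked example of the `read_only` branch above, on made-up notes and the same assumed `access_control` shape: read-only means readable through a group, but not owned, not group-writable, and not public.

```python
# Made-up data; evaluates the read_only rules in plain Python for clarity.
my_id, my_groups = "u-7", {"g-docs"}

notes = {
    "a": {"user_id": "u-1", "access_control": {"read": {"group_ids": ["g-docs"]},
                                               "write": {"group_ids": []}}},
    "b": {"user_id": "u-1", "access_control": {"read": {"group_ids": ["g-docs"]},
                                               "write": {"group_ids": ["g-docs"]}}},
    "c": {"user_id": "u-7", "access_control": {"read": {"group_ids": ["g-docs"]},
                                               "write": {"group_ids": []}}},
    "d": {"user_id": "u-1", "access_control": None},  # public
}

def is_read_only(note):
    ac = note["access_control"]
    if ac is None or note["user_id"] == my_id:
        return False  # public items and owned items are never "read only"
    readable = bool(my_groups & set(ac.get("read", {}).get("group_ids", [])))
    writable = bool(my_groups & set(ac.get("write", {}).get("group_ids", [])))
    return readable and not writable

print({k: is_read_only(n) for k, n in notes.items()})
# {'a': True, 'b': False, 'c': False, 'd': False}
```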
def insert_new_note( def insert_new_note(
self, self,
form_data: NoteForm, form_data: NoteForm,
@@ -232,153 +89,37 @@ class NoteTable:
db.commit() db.commit()
return note return note
def get_notes( def get_notes(self) -> list[NoteModel]:
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[NoteModel]:
with get_db() as db: with get_db() as db:
query = db.query(Note).order_by(Note.updated_at.desc()) notes = db.query(Note).order_by(Note.updated_at.desc()).all()
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
notes = query.all()
return [NoteModel.model_validate(note) for note in notes] return [NoteModel.model_validate(note) for note in notes]
def search_notes(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> NoteListResponse:
with get_db() as db:
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Note.title.ilike(f"%{query_key}%"),
cast(Note.data["content"]["md"], Text).ilike(
f"%{query_key}%"
),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Note.user_id == user_id)
elif view_option == "shared":
query = query.filter(Note.user_id != user_id)
# Apply access control filtering
if "permission" in filter:
permission = filter["permission"]
else:
permission = "write"
query = self._has_permission(
db,
query,
filter,
permission=permission,
)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Note.title.asc())
else:
query = query.order_by(Note.title.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Note.created_at.asc())
else:
query = query.order_by(Note.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Note.updated_at.asc())
else:
query = query.order_by(Note.updated_at.desc())
else:
query = query.order_by(Note.updated_at.desc())
else:
query = query.order_by(Note.updated_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
notes = []
for note, user in items:
notes.append(
NoteUserResponse(
**NoteModel.model_validate(note).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
)
)
return NoteListResponse(items=notes, total=total)
def get_notes_by_user_id( def get_notes_by_user_id(
self, self, user_id: str, permission: str = "write"
user_id: str,
permission: str = "read",
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[NoteModel]: ) -> list[NoteModel]:
with get_db() as db: notes = self.get_notes()
user_group_ids = [ return [
group.id for group in Groups.get_groups_by_member_id(user_id) note
] for note in notes
if note.user_id == user_id
query = db.query(Note).order_by(Note.updated_at.desc()) or has_access(user_id, permission, note.access_control)
query = self._has_permission( ]
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
)
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
notes = query.all()
return [NoteModel.model_validate(note) for note in notes]
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db: with get_db() as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None return NoteModel.model_validate(note) if note else None
def update_note_by_id( def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
self, id: str, form_data: NoteUpdateForm
) -> Optional[NoteModel]:
with get_db() as db: with get_db() as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
if not note: if not note:
return None return None
form_data = form_data.model_dump(exclude_unset=True) note.title = form_data.title
note.data = form_data.data
if "title" in form_data: note.meta = form_data.meta
note.title = form_data["title"] note.access_control = form_data.access_control
if "data" in form_data:
note.data = {**note.data, **form_data["data"]}
if "meta" in form_data:
note.meta = {**note.meta, **form_data["meta"]}
if "access_control" in form_data:
note.access_control = form_data["access_control"]
note.updated_at = int(time.time_ns()) note.updated_at = int(time.time_ns())
db.commit() db.commit()

View file

@@ -1,276 +0,0 @@
import time
import logging
import uuid
from typing import Optional, List
import base64
import hashlib
import json
from cryptography.fernet import Fernet
from open_webui.internal.db import Base, get_db
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Index
log = logging.getLogger(__name__)
####################
# DB MODEL
####################
class OAuthSession(Base):
__tablename__ = "oauth_session"
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False)
token = Column(
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance
__table_args__ = (
Index("idx_oauth_session_user_id", "user_id"),
Index("idx_oauth_session_expires_at", "expires_at"),
Index("idx_oauth_session_user_provider", "user_id", "provider"),
)
class OAuthSessionModel(BaseModel):
id: str
user_id: str
provider: str
token: dict
expires_at: int # timestamp in epoch
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class OAuthSessionResponse(BaseModel):
id: str
user_id: str
provider: str
expires_at: int
class OAuthSessionTable:
def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44:
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
else:
self.encryption_key = self.encryption_key.encode()
try:
self.fernet = Fernet(self.encryption_key)
except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}")
raise
def _encrypt_token(self, token) -> str:
"""Encrypt OAuth tokens for storage"""
try:
token_json = json.dumps(token)
encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
raise
def _decrypt_token(self, token: str):
"""Decrypt OAuth tokens from storage"""
try:
decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {e}")
raise
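The constructor above accepts either a ready-made Fernet key (44 url-safe base64 characters) or an arbitrary secret, which it hashes down to a valid 32-byte key. A standalone round-trip sketch of that normalization plus the encrypt/decrypt of the token JSON (requires the `cryptography` package):

```python
# Standalone sketch of the key handling above: any secret that is not already a
# 44-character Fernet key is SHA-256 hashed and base64-url encoded before use.
import base64
import hashlib
import json

from cryptography.fernet import Fernet

def make_fernet(secret: str) -> Fernet:
    if len(secret) != 44:  # not already in Fernet key format
        key = base64.urlsafe_b64encode(hashlib.sha256(secret.encode()).digest())
    else:
        key = secret.encode()
    return Fernet(key)

f = make_fernet("any-passphrase-will-do")
token = {"access_token": "abc", "refresh_token": "def", "expires_at": 1893456000}
blob = f.encrypt(json.dumps(token).encode()).decode()  # what gets stored in the DB
assert json.loads(f.decrypt(blob.encode()).decode()) == token
```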
def create_session(
self,
user_id: str,
provider: str,
token: dict,
) -> Optional[OAuthSessionModel]:
"""Create a new OAuth session"""
try:
with get_db() as db:
current_time = int(time.time())
id = str(uuid.uuid4())
result = OAuthSession(
**{
"id": id,
"user_id": user_id,
"provider": provider,
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"created_at": current_time,
"updated_at": current_time,
}
)
db.add(result)
db.commit()
db.refresh(result)
if result:
result.token = token # Return decrypted token
return OAuthSessionModel.model_validate(result)
else:
return None
except Exception as e:
log.error(f"Error creating OAuth session: {e}")
return None
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID"""
try:
with get_db() as db:
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_id_and_user_id(
self, session_id: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_provider_and_user_id(
self, provider: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by provider and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(provider=provider, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
with get_db() as db:
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
results = []
for session in sessions:
session.token = self._decrypt_token(session.token)
results.append(OAuthSessionModel.model_validate(session))
return results
except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}")
return []
def update_session_by_id(
self, session_id: str, token: dict
) -> Optional[OAuthSessionModel]:
"""Update OAuth session tokens"""
try:
with get_db() as db:
current_time = int(time.time())
db.query(OAuthSession).filter_by(id=session_id).update(
{
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"updated_at": current_time,
}
)
db.commit()
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}")
return None
def delete_session_by_id(self, session_id: str) -> bool:
"""Delete an OAuth session"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(id=session_id).delete()
db.commit()
return result > 0
except Exception as e:
log.error(f"Error deleting OAuth session: {e}")
return False
def delete_sessions_by_user_id(self, user_id: str) -> bool:
"""Delete all OAuth sessions for a user"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False
def delete_sessions_by_provider(self, provider: str) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db() as db:
db.query(OAuthSession).filter_by(provider=provider).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
return False
OAuthSessions = OAuthSessionTable()

View file

@@ -2,7 +2,6 @@ import time
from typing import Optional from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@@ -104,16 +103,10 @@ class PromptsTable:
def get_prompts(self) -> list[PromptUserResponse]: def get_prompts(self) -> list[PromptUserResponse]:
with get_db() as db: with get_db() as db:
all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
user_ids = list(set(prompt.user_id for prompt in all_prompts))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
prompts = [] prompts = []
for prompt in all_prompts:
user = users_dict.get(prompt.user_id) for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
user = Users.get_user_by_id(prompt.user_id)
prompts.append( prompts.append(
PromptUserResponse.model_validate( PromptUserResponse.model_validate(
{ {
@@ -129,13 +122,12 @@ class PromptsTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]: ) -> list[PromptUserResponse]:
prompts = self.get_prompts() prompts = self.get_prompts()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
prompt prompt
for prompt in prompts for prompt in prompts
if prompt.user_id == user_id if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control, user_group_ids) or has_access(user_id, permission, prompt.access_control)
] ]
def update_prompt_by_command( def update_prompt_by_command(

View file

@@ -6,10 +6,12 @@ from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
@@ -22,11 +24,6 @@ class Tag(Base):
user_id = Column(String) user_id = Column(String)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
__table_args__ = (
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
Index("user_id_idx", "user_id"),
)
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column # Unique constraint ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),) __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)

View file

@@ -4,8 +4,7 @@ from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
from open_webui.models.groups import Groups from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -13,6 +12,7 @@ from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
# Tools DB Schema # Tools DB Schema
@@ -22,7 +22,7 @@ log = logging.getLogger(__name__)
class Tool(Base): class Tool(Base):
__tablename__ = "tool" __tablename__ = "tool"
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True)
user_id = Column(String) user_id = Column(String)
name = Column(Text) name = Column(Text)
content = Column(Text) content = Column(Text)
@@ -93,8 +93,6 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse): class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None user: Optional[UserResponse] = None
model_config = ConfigDict(extra="allow")
class ToolForm(BaseModel): class ToolForm(BaseModel):
id: str id: str
@@ -146,16 +144,9 @@ class ToolsTable:
def get_tools(self) -> list[ToolUserModel]: def get_tools(self) -> list[ToolUserModel]:
with get_db() as db: with get_db() as db:
all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all()
user_ids = list(set(tool.user_id for tool in all_tools))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
tools = [] tools = []
for tool in all_tools: for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = users_dict.get(tool.user_id) user = Users.get_user_by_id(tool.user_id)
tools.append( tools.append(
ToolUserModel.model_validate( ToolUserModel.model_validate(
{ {
@@ -170,13 +161,12 @@ class ToolsTable:
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]: ) -> list[ToolUserModel]:
tools = self.get_tools() tools = self.get_tools()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [ return [
tool tool
for tool in tools for tool in tools
if tool.user_id == user_id if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control, user_group_ids) or has_access(user_id, permission, tool.access_control)
] ]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
@@ -185,7 +175,7 @@ class ToolsTable:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
log.exception(f"Error getting tool valves by id {id}") log.exception(f"Error getting tool valves by id {id}: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:

View file

@@ -4,139 +4,62 @@ from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.groups import Groups, GroupMember from open_webui.models.groups import Groups
from open_webui.models.channels import ChannelMember
from open_webui.utils.misc import throttle
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import ( from sqlalchemy import BigInteger, Column, String, Text
BigInteger, from sqlalchemy import or_
JSON,
Column,
String,
Boolean,
Text,
Date,
exists,
select,
cast,
)
from sqlalchemy import or_, case
from sqlalchemy.dialects.postgresql import JSONB
import datetime
#################### ####################
# User DB Schema # User DB Schema
#################### ####################
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True)
name = Column(String)
email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
oauth_sub = Column(Text, unique=True)
class UserSettings(BaseModel): class UserSettings(BaseModel):
ui: Optional[dict] = {} ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
pass pass
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True, unique=True)
email = Column(String)
username = Column(String(50), nullable=True)
role = Column(String)
name = Column(String)
profile_image_url = Column(Text)
profile_banner_image_url = Column(Text, nullable=True)
bio = Column(Text, nullable=True)
gender = Column(Text, nullable=True)
date_of_birth = Column(Date, nullable=True)
timezone = Column(String, nullable=True)
presence_state = Column(String, nullable=True)
status_emoji = Column(String, nullable=True)
status_message = Column(Text, nullable=True)
status_expires_at = Column(BigInteger, nullable=True)
info = Column(JSON, nullable=True)
settings = Column(JSON, nullable=True)
oauth = Column(JSON, nullable=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class UserModel(BaseModel): class UserModel(BaseModel):
id: str id: str
email: str
username: Optional[str] = None
role: str = "pending"
name: str name: str
email: str
role: str = "pending"
profile_image_url: str profile_image_url: str
profile_banner_image_url: Optional[str] = None
bio: Optional[str] = None
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
timezone: Optional[str] = None
presence_state: Optional[str] = None
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
info: Optional[dict] = None
settings: Optional[UserSettings] = None
oauth: Optional[dict] = None
last_active_at: int # timestamp in epoch last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True) api_key: Optional[str] = None
settings: Optional[UserSettings] = None
info: Optional[dict] = None
oauth_sub: Optional[str] = None
class UserStatusModel(UserModel):
is_active: bool = False
model_config = ConfigDict(from_attributes=True)
class ApiKey(Base):
__tablename__ = "api_key"
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
key = Column(Text, unique=True, nullable=False)
data = Column(JSON, nullable=True)
expires_at = Column(BigInteger, nullable=True)
last_used_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class ApiKeyModel(BaseModel):
id: str
user_id: str
key: str
data: Optional[dict] = None
expires_at: Optional[int] = None
last_used_at: Optional[int] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@@ -146,78 +69,23 @@ class ApiKeyModel(BaseModel):
#################### ####################
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
bio: Optional[str] = None
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
class UserGroupIdsModel(UserModel):
group_ids: list[str] = []
class UserModelResponse(UserModel):
model_config = ConfigDict(extra="allow")
class UserListResponse(BaseModel): class UserListResponse(BaseModel):
users: list[UserModelResponse] users: list[UserModel]
total: int total: int
class UserGroupIdsListResponse(BaseModel): class UserResponse(BaseModel):
users: list[UserGroupIdsModel]
total: int
class UserStatus(BaseModel):
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
class UserInfoResponse(UserStatus):
id: str id: str
name: str name: str
email: str email: str
role: str role: str
profile_image_url: str
class UserIdNameResponse(BaseModel):
id: str
name: str
class UserIdNameStatusResponse(UserStatus):
id: str
name: str
is_active: Optional[bool] = None
class UserInfoListResponse(BaseModel):
users: list[UserInfoResponse]
total: int
class UserIdNameListResponse(BaseModel):
users: list[UserIdNameResponse]
total: int
class UserNameResponse(BaseModel): class UserNameResponse(BaseModel):
id: str id: str
name: str name: str
role: str role: str
class UserResponse(UserNameResponse):
email: str
class UserProfileImageResponse(UserNameResponse):
email: str
profile_image_url: str profile_image_url: str
@ -227,7 +95,6 @@ class UserRoleUpdateForm(BaseModel):
class UserUpdateForm(BaseModel): class UserUpdateForm(BaseModel):
role: str
name: str name: str
email: str email: str
profile_image_url: str profile_image_url: str
@ -242,20 +109,20 @@ class UsersTable:
email: str, email: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth: Optional[dict] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
"email": email,
"name": name, "name": name,
"email": email,
"role": role, "role": role,
"profile_image_url": profile_image_url, "profile_image_url": profile_image_url,
"last_active_at": int(time.time()), "last_active_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
"oauth": oauth, "oauth_sub": oauth_sub,
} }
) )
result = User(**user.model_dump()) result = User(**user.model_dump())
@ -278,13 +145,8 @@ class UsersTable:
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = ( user = db.query(User).filter_by(api_key=api_key).first()
db.query(User) return UserModel.model_validate(user)
.join(ApiKey, User.id == ApiKey.user_id)
.filter(ApiKey.key == api_key)
.first()
)
return UserModel.model_validate(user) if user else None
except Exception: except Exception:
return None return None
@ -296,23 +158,12 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
with get_db() as db: # type: Session with get_db() as db:
dialect_name = db.bind.dialect.name user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
query = db.query(User) except Exception:
if dialect_name == "sqlite":
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
elif dialect_name == "postgresql":
query = query.filter(
User.oauth[provider].cast(JSONB)["sub"].astext == sub
)
user = query.first()
return UserModel.model_validate(user) if user else None
except Exception as e:
# You may want to log the exception here
return None return None
def get_users( def get_users(
@ -320,9 +171,8 @@ class UsersTable:
filter: Optional[dict] = None, filter: Optional[dict] = None,
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> dict: ) -> UserListResponse:
with get_db() as db: with get_db() as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(User) query = db.query(User)
if filter: if filter:
@ -335,76 +185,14 @@ class UsersTable:
) )
) )
channel_id = filter.get("channel_id")
if channel_id:
query = query.filter(
exists(
select(ChannelMember.id).where(
ChannelMember.user_id == User.id,
ChannelMember.channel_id == channel_id,
)
)
)
user_ids = filter.get("user_ids")
group_ids = filter.get("group_ids")
if isinstance(user_ids, list) and isinstance(group_ids, list):
# If both are empty lists, return no users
if not user_ids and not group_ids:
return {"users": [], "total": 0}
if user_ids:
query = query.filter(User.id.in_(user_ids))
if group_ids:
query = query.filter(
exists(
select(GroupMember.id).where(
GroupMember.user_id == User.id,
GroupMember.group_id.in_(group_ids),
)
)
)
roles = filter.get("roles")
if roles:
include_roles = [role for role in roles if not role.startswith("!")]
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
if include_roles:
query = query.filter(User.role.in_(include_roles))
if exclude_roles:
query = query.filter(~User.role.in_(exclude_roles))
order_by = filter.get("order_by") order_by = filter.get("order_by")
direction = filter.get("direction") direction = filter.get("direction")
if order_by and order_by.startswith("group_id:"): if order_by == "name":
group_id = order_by.split(":", 1)[1]
# Subquery that checks if the user belongs to the group
membership_exists = exists(
select(GroupMember.id).where(
GroupMember.user_id == User.id,
GroupMember.group_id == group_id,
)
)
# CASE: user in group → 1, user not in group → 0
group_sort = case((membership_exists, 1), else_=0)
if direction == "asc":
query = query.order_by(group_sort.asc(), User.name.asc())
else:
query = query.order_by(group_sort.desc(), User.name.asc())
elif order_by == "name":
if direction == "asc": if direction == "asc":
query = query.order_by(User.name.asc()) query = query.order_by(User.name.asc())
else: else:
query = query.order_by(User.name.desc()) query = query.order_by(User.name.desc())
elif order_by == "email": elif order_by == "email":
if direction == "asc": if direction == "asc":
query = query.order_by(User.email.asc()) query = query.order_by(User.email.asc())
@ -437,32 +225,18 @@ class UsersTable:
else: else:
query = query.order_by(User.created_at.desc()) query = query.order_by(User.created_at.desc())
# Count BEFORE pagination if skip:
total = query.count()
# Apply pagination only after the total has been counted
if skip is not None:
query = query.offset(skip) query = query.offset(skip)
if limit is not None: if limit:
query = query.limit(limit) query = query.limit(limit)
users = query.all() users = query.all()
return { return {
"users": [UserModel.model_validate(user) for user in users], "users": [UserModel.model_validate(user) for user in users],
"total": total, "total": db.query(User).count(),
} }
def get_users_by_group_id(self, group_id: str) -> list[UserModel]: def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
users = (
db.query(User)
.join(GroupMember, User.id == GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.all()
)
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]:
with get_db() as db: with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
@ -471,10 +245,6 @@ class UsersTable:
with get_db() as db: with get_db() as db:
return db.query(User).count() return db.query(User).count()
def has_users(self) -> bool:
with get_db() as db:
return db.query(db.query(User).exists()).scalar()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
try: try:
with get_db() as db: with get_db() as db:
@ -499,15 +269,6 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_num_users_active_today(self) -> Optional[int]:
with get_db() as db:
current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
query = db.query(User).filter(
User.last_active_at > today_midnight_timestamp
)
return query.count()
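The "active today" window above is computed with plain epoch arithmetic, which truncates to UTC midnight rather than the server's local midnight. A quick illustration of what that truncation does:

```python
import datetime
import time

now = int(time.time())
today_midnight = now - (now % 86400)  # truncate to the start of the current UTC day
print(datetime.datetime.fromtimestamp(today_midnight, tz=datetime.timezone.utc))
# e.g. 2024-05-17 00:00:00+00:00 -- users whose last_active_at is greater than this
# timestamp are counted as active today
```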
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
@ -518,21 +279,6 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_status_by_id(
self, id: str, form_data: UserStatus
) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
{**form_data.model_dump(exclude_none=True)}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def update_user_profile_image_url_by_id( def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
@ -548,8 +294,7 @@ class UsersTable:
except Exception: except Exception:
return None return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
db.query(User).filter_by(id=id).update( db.query(User).filter_by(id=id).update(
@ -562,35 +307,16 @@ class UsersTable:
except Exception: except Exception:
return None return None
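`update_last_active_by_id` is wrapped in the project's `@throttle(...)` decorator, whose implementation is not shown in this diff. As a rough mental model only, a throttle of this kind might look like the hypothetical, per-process stand-in below; the real helper may key on the user id or behave differently:

```python
import functools
import time

def throttle(interval_seconds: float):
    """Hypothetical stand-in: drop calls made sooner than interval_seconds after the last accepted call."""
    def decorator(fn):
        last_accepted = {"t": 0.0}

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            now = time.time()
            if now - last_accepted["t"] < interval_seconds:
                return None  # skip the write; the previous update is recent enough
            last_accepted["t"] = now
            return fn(*args, **kwargs)

        return wrapper
    return decorator

@throttle(10.0)
def update_last_active(user_id: str) -> str:
    return f"updated {user_id}"

print(update_last_active("u1"))  # "updated u1"
print(update_last_active("u1"))  # None -- throttled, no database write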
def update_user_oauth_by_id( def update_user_oauth_sub_by_id(
self, id: str, provider: str, sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
"""
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
Example resulting structure:
{
"google": { "sub": "123" },
"github": { "sub": "abc" }
}
"""
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(id=id).first() db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
if not user:
return None
# Load existing oauth JSON or create empty
oauth = user.oauth or {}
# Update or insert provider entry
oauth[provider] = {"sub": sub}
# Persist updated JSON
db.query(User).filter_by(id=id).update({"oauth": oauth})
db.commit() db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None
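The docstring above describes the per-provider shape of the `oauth` JSON column. Stripped of the database round-trip, the upsert amounts to the following dictionary update (illustrative only, helper name invented):

```python
from typing import Optional

def upsert_oauth(oauth: Optional[dict], provider: str, sub: str) -> dict:
    # Start from the existing JSON (or an empty dict) and insert/overwrite this provider's entry
    updated = dict(oauth or {})
    updated[provider] = {"sub": sub}
    return updated

state = upsert_oauth(None, "google", "123")
state = upsert_oauth(state, "github", "abc")
print(state)  # {'google': {'sub': '123'}, 'github': {'sub': 'abc'}}
```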
@ -603,8 +329,7 @@ class UsersTable:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
# return UserModel(**user.dict()) # return UserModel(**user.dict())
except Exception as e: except Exception:
print(e)
return None return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
@ -644,45 +369,23 @@ class UsersTable:
except Exception: except Exception:
return False return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except Exception:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
with get_db() as db: with get_db() as db:
api_key = db.query(ApiKey).filter_by(user_id=id).first() user = db.query(User).filter_by(id=id).first()
return api_key.key if api_key else None return user.api_key
except Exception: except Exception:
return None return None
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try:
with get_db() as db:
db.query(ApiKey).filter_by(user_id=id).delete()
db.commit()
now = int(time.time())
new_api_key = ApiKey(
id=f"key_{id}",
user_id=id,
key=api_key,
created_at=now,
updated_at=now,
)
db.add(new_api_key)
db.commit()
return True
except Exception:
return False
def delete_user_api_key_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(ApiKey).filter_by(user_id=id).delete()
db.commit()
return True
except Exception:
return False
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
with get_db() as db: with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all() users = db.query(User).filter(User.id.in_(user_ids)).all()
@ -696,23 +399,5 @@ class UsersTable:
else: else:
return None return None
def get_active_user_count(self) -> int:
with get_db() as db:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
count = (
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
)
return count
def is_user_active(self, user_id: str) -> bool:
with get_db() as db:
user = db.query(User).filter_by(id=user_id).first()
if user and user.last_active_at:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
return user.last_active_at >= three_minutes_ago
return False
Users = UsersTable() Users = UsersTable()

View file

@ -15,28 +15,24 @@ class DatalabMarkerLoader:
self, self,
file_path: str, file_path: str,
api_key: str, api_key: str,
api_base_url: str, langs: Optional[str] = None,
additional_config: Optional[str] = None,
use_llm: bool = False, use_llm: bool = False,
skip_cache: bool = False, skip_cache: bool = False,
force_ocr: bool = False, force_ocr: bool = False,
paginate: bool = False, paginate: bool = False,
strip_existing_ocr: bool = False, strip_existing_ocr: bool = False,
disable_image_extraction: bool = False, disable_image_extraction: bool = False,
format_lines: bool = False,
output_format: str = None, output_format: str = None,
): ):
self.file_path = file_path self.file_path = file_path
self.api_key = api_key self.api_key = api_key
self.api_base_url = api_base_url self.langs = langs
self.additional_config = additional_config
self.use_llm = use_llm self.use_llm = use_llm
self.skip_cache = skip_cache self.skip_cache = skip_cache
self.force_ocr = force_ocr self.force_ocr = force_ocr
self.paginate = paginate self.paginate = paginate
self.strip_existing_ocr = strip_existing_ocr self.strip_existing_ocr = strip_existing_ocr
self.disable_image_extraction = disable_image_extraction self.disable_image_extraction = disable_image_extraction
self.format_lines = format_lines
self.output_format = output_format self.output_format = output_format
def _get_mime_type(self, filename: str) -> str: def _get_mime_type(self, filename: str) -> str:
@ -64,7 +60,7 @@ class DatalabMarkerLoader:
return mime_map.get(ext, "application/octet-stream") return mime_map.get(ext, "application/octet-stream")
def check_marker_request_status(self, request_id: str) -> dict: def check_marker_request_status(self, request_id: str) -> dict:
url = f"{self.api_base_url}/{request_id}" url = f"https://www.datalab.to/api/v1/marker/{request_id}"
headers = {"X-Api-Key": self.api_key} headers = {"X-Api-Key": self.api_key}
try: try:
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
@ -85,24 +81,22 @@ class DatalabMarkerLoader:
) )
def load(self) -> List[Document]: def load(self) -> List[Document]:
url = "https://www.datalab.to/api/v1/marker"
filename = os.path.basename(self.file_path) filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename) mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key} headers = {"X-Api-Key": self.api_key}
form_data = { form_data = {
"langs": self.langs,
"use_llm": str(self.use_llm).lower(), "use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(), "skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(), "force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(), "paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(), "strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(), "disable_image_extraction": str(self.disable_image_extraction).lower(),
"format_lines": str(self.format_lines).lower(),
"output_format": self.output_format, "output_format": self.output_format,
} }
if self.additional_config and self.additional_config.strip():
form_data["additional_config"] = self.additional_config
log.info( log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}" f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
) )
@ -111,10 +105,7 @@ class DatalabMarkerLoader:
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)} files = {"file": (filename, f, mime_type)}
response = requests.post( response = requests.post(
f"{self.api_base_url}", url, data=form_data, files=files, headers=headers
data=form_data,
files=files,
headers=headers,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@ -142,92 +133,74 @@ class DatalabMarkerLoader:
check_url = result.get("request_check_url") check_url = result.get("request_check_url")
request_id = result.get("request_id") request_id = result.get("request_id")
if not check_url:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
)
# Check if this is a direct response (self-hosted) or polling response (DataLab) for _ in range(300): # Up to 10 minutes
if check_url: time.sleep(2)
# DataLab polling pattern try:
for _ in range(300): # Up to 10 minutes poll_response = requests.get(check_url, headers=headers)
time.sleep(2) poll_response.raise_for_status()
try: poll_result = poll_response.json()
poll_response = requests.get(check_url, headers=headers) except (requests.HTTPError, ValueError) as e:
poll_response.raise_for_status() raw_body = poll_response.text
poll_result = poll_response.json() log.error(f"Polling error: {e}, response body: {raw_body}")
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
status_val = poll_result.get("status")
success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
)
else:
raise HTTPException( raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
detail="Marker processing timed out",
) )
if not poll_result.get("success", False): status_val = poll_result.get("status")
error_msg = poll_result.get("error") or "Unknown processing error" success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}", detail=f"Marker processing failed: {error_msg}",
) )
# DataLab format - content in format-specific fields
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
final_result = poll_result
else: else:
# Self-hosted direct response - content in "output" field raise HTTPException(
if "output" in result: status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
log.info("Self-hosted Marker returned direct response without polling") )
raw_content = result.get("output")
final_result = result
else:
available_fields = (
list(result.keys())
if isinstance(result, dict)
else "non-dict response"
)
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
)
if self.output_format.lower() == "json": if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
)
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
if content_key == "json":
full_text = json.dumps(raw_content, indent=2) full_text = json.dumps(raw_content, indent=2)
elif self.output_format.lower() in {"markdown", "html"}: elif content_key in {"markdown", "html"}:
full_text = str(raw_content).strip() full_text = str(raw_content).strip()
else: else:
raise HTTPException( raise HTTPException(
@ -238,14 +211,14 @@ class DatalabMarkerLoader:
if not full_text: if not full_text:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail="Marker returned empty content", detail="Datalab Marker returned empty content",
) )
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output") marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
os.makedirs(marker_output_dir, exist_ok=True) os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"} file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(self.output_format.lower(), "txt") file_ext = file_ext_map.get(content_key, "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}" output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
output_path = os.path.join(marker_output_dir, output_filename) output_path = os.path.join(marker_output_dir, output_filename)
@ -258,13 +231,13 @@ class DatalabMarkerLoader:
metadata = { metadata = {
"source": filename, "source": filename,
"output_format": final_result.get("output_format", self.output_format), "output_format": poll_result.get("output_format", self.output_format),
"page_count": final_result.get("page_count", 0), "page_count": poll_result.get("page_count", 0),
"processed_with_llm": self.use_llm, "processed_with_llm": self.use_llm,
"request_id": request_id or "", "request_id": request_id or "",
} }
images = final_result.get("images", {}) images = poll_result.get("images", {})
if images: if images:
metadata["image_count"] = len(images) metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys())) metadata["images"] = json.dumps(list(images.keys()))

View file

@ -1,13 +1,13 @@
import requests import requests
import logging, os import logging
from typing import Iterator, List, Union from typing import Iterator, List, Union
from urllib.parse import quote
from langchain_core.document_loaders import BaseLoader from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.utils.headers import include_user_info_headers from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalDocumentLoader(BaseLoader): class ExternalDocumentLoader(BaseLoader):
@ -17,7 +17,6 @@ class ExternalDocumentLoader(BaseLoader):
url: str, url: str,
api_key: str, api_key: str,
mime_type=None, mime_type=None,
user=None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.url = url self.url = url
@ -26,9 +25,7 @@ class ExternalDocumentLoader(BaseLoader):
self.file_path = file_path self.file_path = file_path
self.mime_type = mime_type self.mime_type = mime_type
self.user = user def load(self) -> list[Document]:
def load(self) -> List[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
data = f.read() data = f.read()
@ -39,51 +36,23 @@ class ExternalDocumentLoader(BaseLoader):
if self.api_key is not None: if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}" headers["Authorization"] = f"Bearer {self.api_key}"
try:
headers["X-Filename"] = quote(os.path.basename(self.file_path))
except Exception:
pass
if self.user is not None:
headers = include_user_info_headers(headers, self.user)
url = self.url url = self.url
if url.endswith("/"): if url.endswith("/"):
url = url[:-1] url = url[:-1]
try: r = requests.put(f"{url}/process", data=data, headers=headers)
response = requests.put(f"{url}/process", data=data, headers=headers)
except Exception as e:
log.error(f"Error connecting to endpoint: {e}")
raise Exception(f"Error connecting to endpoint: {e}")
if response.ok: if r.ok:
res = r.json()
response_data = response.json()
if response_data:
if isinstance(response_data, dict):
return [
Document(
page_content=response_data.get("page_content"),
metadata=response_data.get("metadata"),
)
]
elif isinstance(response_data, list):
documents = []
for document in response_data:
documents.append(
Document(
page_content=document.get("page_content"),
metadata=document.get("metadata"),
)
)
return documents
else:
raise Exception("Error loading document: Unable to parse content")
if res:
return [
Document(
page_content=res.get("page_content"),
metadata=res.get("metadata"),
)
]
else: else:
raise Exception("Error loading document: No content returned") raise Exception("Error loading document: No content returned")
else: else:
raise Exception( raise Exception(f"Error loading document: {r.status_code} {r.text}")
f"Error loading document: {response.status_code} {response.text}"
)
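The updated `ExternalDocumentLoader` accepts either a single object or a list of objects from the external `/process` endpoint. A small illustration of the two payload shapes and how they normalize; plain tuples stand in for `Document` to keep the example dependency-free, and the payloads are invented:

```python
single = {"page_content": "Hello world", "metadata": {"source": "a.pdf"}}
multiple = [
    {"page_content": "Page 1", "metadata": {"page": 1}},
    {"page_content": "Page 2", "metadata": {"page": 2}},
]

def normalize(payload):
    # Wrap a single object in a list, then extract (page_content, metadata) pairs
    items = payload if isinstance(payload, list) else [payload]
    return [(item.get("page_content"), item.get("metadata")) for item in items]

print(normalize(single))    # [('Hello world', {'source': 'a.pdf'})]
print(normalize(multiple))  # [('Page 1', {'page': 1}), ('Page 2', {'page': 2})]
```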

View file

@ -4,8 +4,10 @@ from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalWebLoader(BaseLoader): class ExternalWebLoader(BaseLoader):

View file

@ -2,9 +2,7 @@ import requests
import logging import logging
import ftfy import ftfy
import sys import sys
import json
from azure.identity import DefaultAzureCredential
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader, AzureAIDocumentIntelligenceLoader,
BSHTMLLoader, BSHTMLLoader,
@ -15,7 +13,7 @@ from langchain_community.document_loaders import (
TextLoader, TextLoader,
UnstructuredEPubLoader, UnstructuredEPubLoader,
UnstructuredExcelLoader, UnstructuredExcelLoader,
UnstructuredODTLoader, UnstructuredMarkdownLoader,
UnstructuredPowerPointLoader, UnstructuredPowerPointLoader,
UnstructuredRSTLoader, UnstructuredRSTLoader,
UnstructuredXMLLoader, UnstructuredXMLLoader,
@ -27,13 +25,13 @@ from open_webui.retrieval.loaders.external_document import ExternalDocumentLoade
from open_webui.retrieval.loaders.mistral import MistralLoader from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.retrieval.loaders.mineru import MinerULoader
from open_webui.env import GLOBAL_LOG_LEVEL from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
known_source_ext = [ known_source_ext = [
"go", "go",
@ -78,6 +76,7 @@ known_source_ext = [
"swift", "swift",
"vue", "vue",
"svelte", "svelte",
"msg",
"ex", "ex",
"exs", "exs",
"erl", "erl",
@ -131,9 +130,8 @@ class TikaLoader:
class DoclingLoader: class DoclingLoader:
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None): def __init__(self, url, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/") self.url = url.rstrip("/")
self.api_key = api_key
self.file_path = file_path self.file_path = file_path
self.mime_type = mime_type self.mime_type = mime_type
@ -141,25 +139,36 @@ class DoclingLoader:
def load(self) -> list[Document]: def load(self) -> list[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
headers = {} files = {
if self.api_key: "files": (
headers["X-Api-Key"] = f"Bearer {self.api_key}" self.file_path,
f,
self.mime_type or "application/octet-stream",
)
}
r = requests.post( params = {
f"{self.url}/v1/convert/file", "image_export_mode": "placeholder",
files={ "table_mode": "accurate",
"files": ( }
self.file_path,
f, if self.params:
self.mime_type or "application/octet-stream", if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
) )
},
data={ if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
"image_export_mode": "placeholder", params["ocr_engine"] = self.params.get("ocr_engine")
**self.params, params["ocr_lang"] = [
}, lang.strip()
headers=headers, for lang in self.params.get("ocr_lang").split(",")
) if lang.strip()
]
endpoint = f"{self.url}/v1alpha/convert/file"
r = requests.post(endpoint, files=files, data=params)
if r.ok: if r.ok:
result = r.json() result = r.json()
document_data = result.get("document", {}) document_data = result.get("document", {})
@ -168,6 +177,7 @@ class DoclingLoader:
metadata = {"Content-Type": self.mime_type} if self.mime_type else {} metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text) log.debug("Docling extracted text: %s", text)
return [Document(page_content=text, metadata=metadata)] return [Document(page_content=text, metadata=metadata)]
else: else:
error_msg = f"Error calling Docling API: {r.reason}" error_msg = f"Error calling Docling API: {r.reason}"
@ -184,7 +194,6 @@ class DoclingLoader:
class Loader: class Loader:
def __init__(self, engine: str = "", **kwargs): def __init__(self, engine: str = "", **kwargs):
self.engine = engine self.engine = engine
self.user = kwargs.get("user", None)
self.kwargs = kwargs self.kwargs = kwargs
def load( def load(
@ -202,10 +211,7 @@ class Loader:
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or ( return file_ext in known_source_ext or (
file_content_type file_content_type and file_content_type.find("text/") >= 0
and file_content_type.find("text/") >= 0
# Avoid text/html files being detected as text
and not file_content_type.find("html") >= 0
) )
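The extra condition above keeps `text/html` out of the plain-text path so HTML is routed to a structured loader instead. A standalone check of that rule, with an abbreviated extension list and invented names:

```python
KNOWN_SOURCE_EXT = {"go", "py", "ts", "svelte"}  # abbreviated stand-in for the full list above

def is_text_file(file_ext: str, file_content_type) -> bool:
    return file_ext in KNOWN_SOURCE_EXT or (
        bool(file_content_type)
        and "text/" in file_content_type
        and "html" not in file_content_type  # e.g. text/html should not be ingested as raw text
    )

print(is_text_file("py", None))           # True  -- known source extension
print(is_text_file("bin", "text/plain"))  # True  -- generic text content type
print(is_text_file("bin", "text/html"))   # False -- HTML is excluded from the text path
```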
def _get_loader(self, filename: str, file_content_type: str, file_path: str): def _get_loader(self, filename: str, file_content_type: str, file_path: str):
@ -221,7 +227,6 @@ class Loader:
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"), url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"), api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
mime_type=file_content_type, mime_type=file_content_type,
user=self.user,
) )
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type): if self._is_text_file(file_ext, file_content_type):
@ -230,6 +235,7 @@ class Loader:
loader = TikaLoader( loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"), url=self.kwargs.get("TIKA_SERVER_URL"),
file_path=file_path, file_path=file_path,
mime_type=file_content_type,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
) )
elif ( elif (
@ -257,15 +263,10 @@ class Loader:
"tiff", "tiff",
] ]
): ):
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
if not api_base_url or api_base_url.strip() == "":
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
loader = DatalabMarkerLoader( loader = DatalabMarkerLoader(
file_path=file_path, file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"], api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
api_base_url=api_base_url, langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False), use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False), skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False), force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
@ -276,7 +277,6 @@ class Loader:
disable_image_extraction=self.kwargs.get( disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
), ),
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
output_format=self.kwargs.get( output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown" "DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
), ),
@ -285,67 +285,38 @@ class Loader:
if self._is_text_file(file_ext, file_content_type): if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
# Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {})
if not isinstance(params, dict):
try:
params = json.loads(params)
except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
params = {}
loader = DoclingLoader( loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"), url=self.kwargs.get("DOCLING_SERVER_URL"),
api_key=self.kwargs.get("DOCLING_API_KEY", None),
file_path=file_path, file_path=file_path,
mime_type=file_content_type, mime_type=file_content_type,
params=params, params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
) )
elif ( elif (
self.engine == "document_intelligence" self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and ( and (
file_ext in ["pdf", "docx", "ppt", "pptx"] file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type or file_content_type
in [ in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document", "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint", "application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
] ]
) )
): ):
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "": loader = AzureAIDocumentIntelligenceLoader(
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
)
else:
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
azure_credential=DefaultAzureCredential(),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
)
elif self.engine == "mineru" and file_ext in [
"pdf"
]: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
if mineru_timeout:
try:
mineru_timeout = int(mineru_timeout)
except ValueError:
mineru_timeout = 300
loader = MinerULoader(
file_path=file_path, file_path=file_path,
api_mode=self.kwargs.get("MINERU_API_MODE", "local"), api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"), api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
api_key=self.kwargs.get("MINERU_API_KEY", ""),
params=self.kwargs.get("MINERU_PARAMS", {}),
timeout=mineru_timeout,
) )
elif ( elif (
self.engine == "mistral_ocr" self.engine == "mistral_ocr"
@ -354,9 +325,16 @@ class Loader:
in ["pdf"] # Mistral OCR currently only supports PDF and images in ["pdf"] # Mistral OCR currently only supports PDF and images
): ):
loader = MistralLoader( loader = MistralLoader(
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"), api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), )
file_path=file_path, elif (
self.engine == "external"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
) )
else: else:
if file_ext == "pdf": if file_ext == "pdf":
@ -393,8 +371,6 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path) loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg": elif file_ext == "msg":
loader = OutlookMessageLoader(file_path) loader = OutlookMessageLoader(file_path)
elif file_ext == "odt":
loader = UnstructuredODTLoader(file_path)
elif self._is_text_file(file_ext, file_content_type): elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
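Earlier in this file, the Docling branch accepts `DOCLING_PARAMS` either as a dict or as a JSON string and falls back to `{}` on bad input. A minimal illustration of that coercion (the helper name is invented):

```python
import json
import logging

log = logging.getLogger(__name__)

def coerce_docling_params(raw) -> dict:
    # Accept a dict as-is; otherwise try to parse it as a JSON object
    if isinstance(raw, dict):
        return raw
    try:
        return json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        log.error("Invalid DOCLING_PARAMS format, expected JSON object")
        return {}

print(coerce_docling_params({"image_export_mode": "placeholder"}))
print(coerce_docling_params('{"table_mode": "accurate"}'))
print(coerce_docling_params("not json"))  # -> {} after logging the error
```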

View file

@ -1,524 +0,0 @@
import os
import time
import requests
import logging
import tempfile
import zipfile
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status
log = logging.getLogger(__name__)
class MinerULoader:
"""
MinerU document parser loader supporting both Cloud API and Local API modes.
Cloud API: uses the MinerU managed service with asynchronous, task-based processing
Local API: uses a self-hosted MinerU API with synchronous processing
"""
def __init__(
self,
file_path: str,
api_mode: str = "local",
api_url: str = "http://localhost:8000",
api_key: str = "",
params: dict = None,
timeout: Optional[int] = 300,
):
self.file_path = file_path
self.api_mode = api_mode.lower()
self.api_url = api_url.rstrip("/")
self.api_key = api_key
self.timeout = timeout
# Parse params dict with defaults
self.params = params or {}
# Read from self.params so that params=None cannot raise AttributeError
self.enable_ocr = self.params.get("enable_ocr", False)
self.enable_formula = self.params.get("enable_formula", True)
self.enable_table = self.params.get("enable_table", True)
self.language = self.params.get("language", "en")
self.model_version = self.params.get("model_version", "pipeline")
self.page_ranges = self.params.pop("page_ranges", "")
# Validate API mode
if self.api_mode not in ["local", "cloud"]:
raise ValueError(
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
)
# Validate Cloud API requirements
if self.api_mode == "cloud" and not self.api_key:
raise ValueError("API key is required for Cloud API mode")
def load(self) -> List[Document]:
"""
Main entry point for loading and parsing the document.
Routes to Cloud or Local API based on api_mode.
"""
try:
if self.api_mode == "cloud":
return self._load_cloud_api()
else:
return self._load_local_api()
except Exception as e:
log.error(f"Error loading document with MinerU: {e}")
raise
def _load_local_api(self) -> List[Document]:
"""
Load document using Local API (synchronous).
Posts file to /file_parse endpoint and gets immediate response.
"""
log.info(f"Using MinerU Local API at {self.api_url}")
filename = os.path.basename(self.file_path)
# Build form data for Local API
form_data = {
**self.params,
"return_md": "true",
}
# Page ranges (Local API uses start_page_id and end_page_id)
if self.page_ranges:
# The Cloud-style page_ranges string is not translated here; full support
# would require converting it into the Local API's start_page_id/end_page_id.
log.warning(
f"Page ranges '{self.page_ranges}' specified, but the Local API uses a different format. "
"Consider using start_page_id/end_page_id parameters if needed."
)
try:
with open(self.file_path, "rb") as f:
files = {"files": (filename, f, "application/octet-stream")}
log.info(f"Sending file to MinerU Local API: {filename}")
log.debug(f"Local API parameters: {form_data}")
response = requests.post(
f"{self.api_url}/file_parse",
data=form_data,
files=files,
timeout=self.timeout,
)
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU Local API request timed out",
)
except requests.HTTPError as e:
error_detail = f"MinerU Local API request failed: {e}"
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data}"
except Exception:
error_detail += f" - {e.response.text}"
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error calling MinerU Local API: {str(e)}",
)
# Parse response
try:
result = response.json()
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response from MinerU Local API: {e}",
)
# Extract markdown content from response
if "results" not in result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Local API response missing 'results' field",
)
results = result["results"]
if not results:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty results",
)
# Get the first (and typically only) result
file_result = list(results.values())[0]
markdown_content = file_result.get("md_content", "")
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty markdown content",
)
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
# Create metadata
metadata = {
"source": filename,
"api_mode": "local",
"backend": result.get("backend", "unknown"),
"version": result.get("version", "unknown"),
}
return [Document(page_content=markdown_content, metadata=metadata)]
def _load_cloud_api(self) -> List[Document]:
"""
Load document using Cloud API (asynchronous).
Uses batch upload endpoint to avoid need for public file URLs.
"""
log.info(f"Using MinerU Cloud API at {self.api_url}")
filename = os.path.basename(self.file_path)
# Step 1: Request presigned upload URL
batch_id, upload_url = self._request_upload_url(filename)
# Step 2: Upload file to presigned URL
self._upload_to_presigned_url(upload_url)
# Step 3: Poll for results
result = self._poll_batch_status(batch_id, filename)
# Step 4: Download and extract markdown from ZIP
markdown_content = self._download_and_extract_zip(
result["full_zip_url"], filename
)
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
# Create metadata
metadata = {
"source": filename,
"api_mode": "cloud",
"batch_id": batch_id,
}
return [Document(page_content=markdown_content, metadata=metadata)]
def _request_upload_url(self, filename: str) -> tuple:
"""
Request presigned upload URL from Cloud API.
Returns (batch_id, upload_url).
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
# Build request body
request_body = {
**self.params,
"files": [
{
"name": filename,
"is_ocr": self.enable_ocr,
}
],
}
# Add page ranges if specified
if self.page_ranges:
request_body["files"][0]["page_ranges"] = self.page_ranges
log.info(f"Requesting upload URL for: {filename}")
log.debug(f"Cloud API request body: {request_body}")
try:
response = requests.post(
f"{self.api_url}/file-urls/batch",
headers=headers,
json=request_body,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to request upload URL: {e}"
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
except Exception:
error_detail += f" - {e.response.text}"
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error requesting upload URL: {str(e)}",
)
try:
result = response.json()
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response: {e}",
)
# Check for API error response
if result.get("code") != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
)
data = result.get("data", {})
batch_id = data.get("batch_id")
file_urls = data.get("file_urls", [])
if not batch_id or not file_urls:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Cloud API response missing batch_id or file_urls",
)
upload_url = file_urls[0]
log.info(f"Received upload URL for batch: {batch_id}")
return batch_id, upload_url
def _upload_to_presigned_url(self, upload_url: str) -> None:
"""
Upload file to presigned URL (no authentication needed).
"""
log.info(f"Uploading file to presigned URL")
try:
with open(self.file_path, "rb") as f:
response = requests.put(
upload_url,
data=f,
timeout=self.timeout,
)
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="File upload to presigned URL timed out",
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to upload file to presigned URL: {e}",
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error uploading file: {str(e)}",
)
log.info("File uploaded successfully")
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
"""
Poll batch status until completion.
Returns the result dict for the file.
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
}
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
poll_interval = 2 # seconds
log.info(f"Polling batch status: {batch_id}")
for iteration in range(max_iterations):
try:
response = requests.get(
f"{self.api_url}/extract-results/batch/{batch_id}",
headers=headers,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to poll batch status: {e}"
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
except Exception:
error_detail += f" - {e.response.text}"
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error polling batch status: {str(e)}",
)
try:
result = response.json()
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response while polling: {e}",
)
# Check for API error response
if result.get("code") != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
)
data = result.get("data", {})
extract_result = data.get("extract_result", [])
# Find our file in the batch results
file_result = None
for item in extract_result:
if item.get("file_name") == filename:
file_result = item
break
if not file_result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"File {filename} not found in batch results",
)
state = file_result.get("state")
if state == "done":
log.info(f"Processing complete for {filename}")
return file_result
elif state == "failed":
error_msg = file_result.get("err_msg", "Unknown error")
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU processing failed: {error_msg}",
)
elif state in ["waiting-file", "pending", "running", "converting"]:
# Still processing
if iteration % 10 == 0: # Log every 20 seconds
log.info(
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
)
time.sleep(poll_interval)
else:
log.warning(f"Unknown state: {state}")
time.sleep(poll_interval)
# Timeout
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU processing timed out after 10 minutes",
)
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
"""
Download ZIP file from CDN and extract markdown content.
Returns the markdown content as a string.
"""
log.info(f"Downloading results from: {zip_url}")
try:
response = requests.get(zip_url, timeout=60)
response.raise_for_status()
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to download results ZIP: {e}",
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading results: {str(e)}",
)
# Save ZIP to temporary file and extract
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
tmp_zip.write(response.content)
tmp_zip_path = tmp_zip.name
with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
zip_ref.extractall(tmp_dir)
# Find markdown file - search recursively for any .md file
markdown_content = None
found_md_path = None
# First, list all files in the ZIP for debugging
all_files = []
for root, dirs, files in os.walk(tmp_dir):
for file in files:
full_path = os.path.join(root, file)
all_files.append(full_path)
# Look for any .md file
if file.endswith(".md"):
found_md_path = full_path
log.info(f"Found markdown file at: {full_path}")
try:
with open(full_path, "r", encoding="utf-8") as f:
markdown_content = f.read()
if (
markdown_content
): # Use the first non-empty markdown file
break
except Exception as e:
log.warning(f"Failed to read {full_path}: {e}")
if markdown_content:
break
if markdown_content is None:
log.error(f"Available files in ZIP: {all_files}")
# Try to provide more helpful error message
md_files = [f for f in all_files if f.endswith(".md")]
if md_files:
error_msg = (
f"Found .md files but couldn't read them: {md_files}"
)
else:
error_msg = (
f"No .md files found in ZIP. Available files: {all_files}"
)
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=error_msg,
)
# Clean up temporary ZIP file
os.unlink(tmp_zip_path)
except zipfile.BadZipFile as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid ZIP file received: {e}",
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting ZIP: {str(e)}",
)
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Extracted markdown content is empty",
)
log.info(
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
)
return markdown_content
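A hedged usage sketch of the loader above, assuming `MinerULoader` is imported from `open_webui.retrieval.loaders.mineru` (as wired up earlier in this diff) and that a local MinerU server is reachable; the path, URL and parameter values are placeholders:

```python
from open_webui.retrieval.loaders.mineru import MinerULoader

loader = MinerULoader(
    file_path="/tmp/example.pdf",     # placeholder path
    api_mode="local",                 # "local" -> synchronous /file_parse; "cloud" -> async batch API
    api_url="http://localhost:8000",  # placeholder endpoint
    params={"enable_ocr": False, "enable_formula": True, "language": "en"},
    timeout=300,
)
docs = loader.load()     # one langchain Document containing the extracted markdown
print(docs[0].metadata)  # e.g. {'source': 'example.pdf', 'api_mode': 'local', ...}
```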

View file

@ -9,29 +9,23 @@ from typing import List, Dict, Any
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.env import GLOBAL_LOG_LEVEL from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MistralLoader: class MistralLoader:
""" """
Enhanced Mistral OCR loader with both sync and async support. Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API. Loads documents by processing them through the Mistral OCR API.
Performance Optimizations:
- Differentiated timeouts for different operations
- Intelligent retry logic with exponential backoff
- Memory-efficient file streaming for large files
- Connection pooling and keepalive optimization
- Semaphore-based concurrency control for batch processing
- Enhanced error handling with retryable error classification
""" """
BASE_API_URL = "https://api.mistral.ai/v1"
def __init__( def __init__(
self, self,
base_url: str,
api_key: str, api_key: str,
file_path: str, file_path: str,
timeout: int = 300, # 5 minutes default timeout: int = 300, # 5 minutes default
@ -53,49 +47,23 @@ class MistralLoader:
if not os.path.exists(file_path): if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}") raise FileNotFoundError(f"File not found at {file_path}")
self.base_url = (
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
)
self.api_key = api_key self.api_key = api_key
self.file_path = file_path self.file_path = file_path
self.timeout = timeout self.timeout = timeout
self.max_retries = max_retries self.max_retries = max_retries
self.debug = enable_debug_logging self.debug = enable_debug_logging
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations # Pre-compute file info for performance
# This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min(
timeout, 120
) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = (
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
self.file_name = os.path.basename(file_path) self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path) self.file_size = os.path.getsize(file_path)
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = { self.headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage "User-Agent": "OpenWebUI-MistralLoader/2.0",
} }
def _debug_log(self, message: str, *args) -> None: def _debug_log(self, message: str, *args) -> None:
""" """Conditional debug logging for performance."""
PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
Only processes debug messages when debug mode is enabled, avoiding
string formatting overhead in production environments.
"""
if self.debug: if self.debug:
log.debug(message, *args) log.debug(message, *args)
@ -147,118 +115,53 @@ class MistralLoader:
log.error(f"Unexpected error processing response: {e}") log.error(f"Unexpected error processing response: {e}")
raise raise
def _is_retryable_error(self, error: Exception) -> bool:
"""
ENHANCEMENT: Intelligent error classification for retry logic.
Determines if an error is retryable based on its type and status code.
This prevents wasting time retrying errors that will never succeed
(like authentication errors) while ensuring transient errors are retried.
Retryable errors:
- Network connection errors (temporary network issues)
- Timeouts (server might be temporarily overloaded)
- Server errors (5xx status codes - server-side issues)
- Rate limiting (429 status - temporary throttling)
Non-retryable errors:
- Authentication errors (401, 403 - won't fix with retry)
- Bad request errors (400 - malformed request)
- Not found errors (404 - resource doesn't exist)
"""
if isinstance(error, requests.exceptions.ConnectionError):
return True # Network issues are usually temporary
if isinstance(error, requests.exceptions.Timeout):
return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None:
status_code = error.response.status_code
return status_code >= 500 or status_code == 429
return False
if isinstance(
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429
return False # All other errors are non-retryable
def _retry_request_sync(self, request_func, *args, **kwargs): def _retry_request_sync(self, request_func, *args, **kwargs):
""" """Synchronous retry logic with exponential backoff."""
ENHANCEMENT: Synchronous retry logic with intelligent error classification.
Uses exponential backoff with jitter to avoid thundering herd problems.
The wait time increases exponentially but is capped at 30 seconds to
prevent excessive delays. Only retries errors that are likely to succeed
on subsequent attempts.
"""
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
return request_func(*args, **kwargs) return request_func(*args, **kwargs)
except Exception as e: except (requests.exceptions.RequestException, Exception) as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e): if attempt == self.max_retries - 1:
raise raise
# PERFORMANCE OPTIMIZATION: Exponential backoff with cap wait_time = (2**attempt) + 0.5
# Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning( log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
f"Retrying in {wait_time}s..."
) )
time.sleep(wait_time) time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs): async def _retry_request_async(self, request_func, *args, **kwargs):
""" """Async retry logic with exponential backoff."""
ENHANCEMENT: Async retry logic with intelligent error classification.
Async version of retry logic that doesn't block the event loop during
wait periods. Uses the same exponential backoff strategy as sync version.
"""
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
return await request_func(*args, **kwargs) return await request_func(*args, **kwargs)
except Exception as e: except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e): if attempt == self.max_retries - 1:
raise raise
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff wait_time = (2**attempt) + 0.5
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning( log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
f"Retrying in {wait_time}s..."
) )
await asyncio.sleep(wait_time) # Non-blocking wait await asyncio.sleep(wait_time)
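A self-contained sketch of the capped exponential backoff schedule used by both retry helpers above; the retryable-error check is left out for brevity, so this is a simplification, not the loader's implementation.

import time


def retry_with_backoff(func, max_retries: int = 3, cap: float = 30.0):
    # Simplified: retries every exception; the real helpers also classify errors.
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            wait_time = min((2 ** attempt) + 0.5, cap)  # 1.5s, 2.5s, 4.5s, ... capped at 30s
            print(f"attempt {attempt + 1}/{max_retries} failed ({e}); retrying in {wait_time}s")
            time.sleep(wait_time)


attempts = {"n": 0}

def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("temporary outage")
    return "ok"

print(retry_with_backoff(flaky))  # succeeds on the third attempt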
def _upload_file(self) -> str: def _upload_file(self) -> str:
""" """Uploads the file to Mistral for OCR processing (sync version)."""
PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.
Uploads the file to Mistral for OCR processing (sync version).
Uses context manager for file handling to ensure proper resource cleanup.
Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration.
"""
log.info("Uploading file to Mistral API") log.info("Uploading file to Mistral API")
url = f"{self.base_url}/files" url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
def upload_request(): def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
files = {"file": (self.file_name, f, "application/pdf")} files = {"file": (file_name, f, "application/pdf")}
data = {"purpose": "ocr"} data = {"purpose": "ocr"}
# NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint
response = requests.post( response = requests.post(
url, url,
headers=self.headers, headers=self.headers,
files=files, files=files,
data=data, data=data,
timeout=self.upload_timeout, # Use specialized upload timeout timeout=self.timeout,
stream=False, # Keep as False for this endpoint
) )
return self._handle_response(response) return self._handle_response(response)
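A hedged sketch of the upload call shape: a multipart POST with Bearer auth, an extra form field, and the shorter upload timeout. The base URL and API key are placeholders for illustration, not guaranteed endpoints.

import os
import requests

API_BASE = "https://api.mistral.ai/v1"   # assumed base URL, for illustration only
API_KEY = os.environ.get("MISTRAL_API_KEY", "sk-...")
file_path = "document.pdf"

with open(file_path, "rb") as f:  # context manager keeps the handle short-lived
    response = requests.post(
        f"{API_BASE}/files",
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": (os.path.basename(file_path), f, "application/pdf")},
        data={"purpose": "ocr"},
        timeout=120,  # the capped upload timeout
    )
response.raise_for_status()
print(response.json().get("id"))  # file id consumed by the signed-URL step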
@ -276,7 +179,7 @@ class MistralLoader:
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency.""" """Async file upload with streaming for better memory efficiency."""
url = f"{self.base_url}/files" url = f"{self.BASE_API_URL}/files"
async def upload_request(): async def upload_request():
# Create multipart writer for streaming upload # Create multipart writer for streaming upload
@ -306,7 +209,7 @@ class MistralLoader:
url, url,
data=writer, data=writer,
headers=self.headers, headers=self.headers,
timeout=aiohttp.ClientTimeout(total=self.upload_timeout), timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response: ) as response:
return await self._handle_response_async(response) return await self._handle_response_async(response)
@ -322,13 +225,13 @@ class MistralLoader:
def _get_signed_url(self, file_id: str) -> str: def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version).""" """Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}") log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}/url" url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1} params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"} signed_url_headers = {**self.headers, "Accept": "application/json"}
def url_request(): def url_request():
response = requests.get( response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.url_timeout url, headers=signed_url_headers, params=params, timeout=self.timeout
) )
return self._handle_response(response) return self._handle_response(response)
@ -347,7 +250,7 @@ class MistralLoader:
self, session: aiohttp.ClientSession, file_id: str self, session: aiohttp.ClientSession, file_id: str
) -> str: ) -> str:
"""Async signed URL retrieval.""" """Async signed URL retrieval."""
url = f"{self.base_url}/files/{file_id}/url" url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1} params = {"expiry": 1}
headers = {**self.headers, "Accept": "application/json"} headers = {**self.headers, "Accept": "application/json"}
@ -358,7 +261,7 @@ class MistralLoader:
url, url,
headers=headers, headers=headers,
params=params, params=params,
timeout=aiohttp.ClientTimeout(total=self.url_timeout), timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response: ) as response:
return await self._handle_response_async(response) return await self._handle_response_async(response)
@ -374,7 +277,7 @@ class MistralLoader:
def _process_ocr(self, signed_url: str) -> Dict[str, Any]: def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing (sync version).""" """Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API") log.info("Processing OCR via Mistral API")
url = f"{self.base_url}/ocr" url = f"{self.BASE_API_URL}/ocr"
ocr_headers = { ocr_headers = {
**self.headers, **self.headers,
"Content-Type": "application/json", "Content-Type": "application/json",
@ -391,7 +294,7 @@ class MistralLoader:
def ocr_request(): def ocr_request():
response = requests.post( response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout url, headers=ocr_headers, json=payload, timeout=self.timeout
) )
return self._handle_response(response) return self._handle_response(response)
@ -408,7 +311,7 @@ class MistralLoader:
self, session: aiohttp.ClientSession, signed_url: str self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Async OCR processing with timing metrics.""" """Async OCR processing with timing metrics."""
url = f"{self.base_url}/ocr" url = f"{self.BASE_API_URL}/ocr"
headers = { headers = {
**self.headers, **self.headers,
@ -433,7 +336,7 @@ class MistralLoader:
url, url,
json=payload, json=payload,
headers=headers, headers=headers,
timeout=aiohttp.ClientTimeout(total=self.ocr_timeout), timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response: ) as response:
ocr_response = await self._handle_response_async(response) ocr_response = await self._handle_response_async(response)
@ -447,12 +350,10 @@ class MistralLoader:
def _delete_file(self, file_id: str) -> None: def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage (sync version).""" """Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}") log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}" url = f"{self.BASE_API_URL}/files/{file_id}"
try: try:
response = requests.delete( response = requests.delete(url, headers=self.headers, timeout=30)
url, headers=self.headers, timeout=self.cleanup_timeout
)
delete_response = self._handle_response(response) delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}") log.info(f"File deleted successfully: {delete_response}")
except Exception as e: except Exception as e:
@ -468,10 +369,10 @@ class MistralLoader:
async def delete_request(): async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}") self._debug_log(f"Deleting file ID: {file_id}")
async with session.delete( async with session.delete(
url=f"{self.base_url}/files/{file_id}", url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers, headers=self.headers,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
total=self.cleanup_timeout total=30
), # Shorter timeout for cleanup ), # Shorter timeout for cleanup
) as response: ) as response:
return await self._handle_response_async(response) return await self._handle_response_async(response)
@ -487,40 +388,29 @@ class MistralLoader:
async def _get_session(self): async def _get_session(self):
"""Context manager for HTTP session with optimized settings.""" """Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector( connector = aiohttp.TCPConnector(
limit=20, # Increased total connection limit for better throughput limit=10, # Total connection limit
limit_per_host=10, # Increased per-host limit for API endpoints limit_per_host=5, # Per-host connection limit
ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes) ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True, use_dns_cache=True,
keepalive_timeout=60, # Increased keepalive for connection reuse keepalive_timeout=30,
enable_cleanup_closed=True, enable_cleanup_closed=True,
force_close=False, # Allow connection reuse
resolver=aiohttp.AsyncResolver(), # Use async DNS resolver
)
timeout = aiohttp.ClientTimeout(
total=self.timeout,
connect=30, # Connection timeout
sock_read=60, # Socket read timeout
) )
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
connector=connector, connector=connector,
timeout=timeout, timeout=aiohttp.ClientTimeout(total=self.timeout),
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
trust_env=True,
) as session: ) as session:
yield session yield session
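A standalone sketch of building a tuned aiohttp session along the lines of the context manager above; the connector and timeout values mirror the ones in the diff, while the URL is a placeholder.

import asyncio
import aiohttp


async def fetch_status(url: str) -> int:
    # Connector settings mirror the tuned values shown above.
    connector = aiohttp.TCPConnector(
        limit=20,              # total pooled connections
        limit_per_host=10,     # per-host pool for a single API endpoint
        ttl_dns_cache=600,     # cache DNS lookups for 10 minutes
        keepalive_timeout=60,  # keep idle connections around for reuse
        enable_cleanup_closed=True,
    )
    timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=60)
    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        async with session.get(url) as response:
            return response.status


print(asyncio.run(fetch_status("https://example.com")))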
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata and memory efficiency.""" """Process OCR results into Document objects with enhanced metadata."""
pages_data = ocr_response.get("pages") pages_data = ocr_response.get("pages")
if not pages_data: if not pages_data:
log.warning("No pages found in OCR response.") log.warning("No pages found in OCR response.")
return [ return [
Document( Document(
page_content="No text content found", page_content="No text content found", metadata={"error": "no_pages"}
metadata={"error": "no_pages", "file_name": self.file_name},
) )
] ]
@ -528,44 +418,41 @@ class MistralLoader:
total_pages = len(pages_data) total_pages = len(pages_data)
skipped_pages = 0 skipped_pages = 0
# Process pages in a memory-efficient way
for page_data in pages_data: for page_data in pages_data:
page_content = page_data.get("markdown") page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index page_index = page_data.get("index") # API uses 0-based index
if page_content is None or page_index is None: if page_content is not None and page_index is not None:
# Clean up content efficiently
cleaned_content = (
page_content.strip()
if isinstance(page_content, str)
else str(page_content)
)
if cleaned_content: # Only add non-empty pages
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
},
)
)
else:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
else:
skipped_pages += 1 skipped_pages += 1
self._debug_log( self._debug_log(
f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}" f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
) )
continue
# Clean up content efficiently with early exit for empty content
if isinstance(page_content, str):
cleaned_content = page_content.strip()
else:
cleaned_content = str(page_content).strip()
if not cleaned_content:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
continue
# Create document with optimized metadata
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
"content_length": len(cleaned_content),
},
)
)
if skipped_pages > 0: if skipped_pages > 0:
log.info( log.info(
@ -580,11 +467,7 @@ class MistralLoader:
return [ return [
Document( Document(
page_content="No valid text content found in document", page_content="No valid text content found in document",
metadata={ metadata={"error": "no_valid_pages", "total_pages": total_pages},
"error": "no_valid_pages",
"total_pages": total_pages,
"file_name": self.file_name,
},
) )
] ]
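A small example of consuming the per-page metadata produced above (page, page_label, total_pages); the documents here are built by hand purely for illustration.

from langchain_core.documents import Document

docs = [
    Document(page_content="First page text", metadata={"page": 0, "page_label": 1, "total_pages": 2}),
    Document(page_content="Second page text", metadata={"page": 1, "page_label": 2, "total_pages": 2}),
]

for doc in docs:
    # page_label is the human-friendly 1-based number; page is the API's 0-based index.
    label = doc.metadata.get("page_label", "?")
    total = doc.metadata.get("total_pages", "?")
    print(f"[page {label}/{total}] {doc.page_content[:40]}")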
@ -702,14 +585,12 @@ class MistralLoader:
@staticmethod @staticmethod
async def load_multiple_async( async def load_multiple_async(
loaders: List["MistralLoader"], loaders: List["MistralLoader"],
max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]: ) -> List[List[Document]]:
""" """
Process multiple files concurrently with controlled concurrency. Process multiple files concurrently for maximum performance.
Args: Args:
loaders: List of MistralLoader instances loaders: List of MistralLoader instances
max_concurrent: Maximum number of concurrent requests
Returns: Returns:
List of document lists, one for each loader List of document lists, one for each loader
@ -717,20 +598,11 @@ class MistralLoader:
if not loaders: if not loaders:
return [] return []
log.info( log.info(f"Starting concurrent processing of {len(loaders)} files")
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
start_time = time.time() start_time = time.time()
# Use semaphore to control concurrency # Process all files concurrently
semaphore = asyncio.Semaphore(max_concurrent) tasks = [loader.load_async() for loader in loaders]
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
async with semaphore:
return await loader.load_async()
# Process all files with controlled concurrency
tasks = [process_with_semaphore(loader) for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results # Handle any exceptions in results
@ -752,18 +624,10 @@ class MistralLoader:
else: else:
processed_results.append(result) processed_results.append(result)
# MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results) total_docs = sum(len(docs) for docs in processed_results)
success_count = sum(
1 for result in results if not isinstance(result, Exception)
)
failure_count = len(results) - success_count
log.info( log.info(
f"Batch processing completed in {total_time:.2f}s: " f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
f"{success_count} files succeeded, {failure_count} files failed, "
f"produced {total_docs} total documents"
) )
return processed_results return processed_results
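A generic sketch of the semaphore-bounded fan-out used by load_multiple_async: run many coroutines but never more than max_concurrent at once, and keep per-task failures from cancelling the whole batch. The worker here is a stand-in, not the real loader.

import asyncio


async def worker(i: int) -> str:
    # Stand-in for loader.load_async(); one task fails on purpose.
    await asyncio.sleep(0.1)
    if i == 2:
        raise RuntimeError("simulated failure")
    return f"result-{i}"


async def run_bounded(n: int, max_concurrent: int = 5):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def guarded(i: int):
        async with semaphore:
            return await worker(i)

    results = await asyncio.gather(*(guarded(i) for i in range(n)), return_exceptions=True)
    ok = [r for r in results if not isinstance(r, Exception)]
    print(f"{len(ok)} succeeded, {len(results) - len(ok)} failed")
    return results


asyncio.run(run_bounded(8, max_concurrent=3))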

View file

@ -4,8 +4,10 @@ from typing import Iterator, List, Literal, Union
from langchain_core.document_loaders import BaseLoader from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class TavilyLoader(BaseLoader): class TavilyLoader(BaseLoader):

View file

@ -1,11 +1,12 @@
import logging import logging
from xml.etree.ElementTree import ParseError
from typing import Any, Dict, Generator, List, Optional, Sequence, Union from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
ALLOWED_SCHEMES = {"http", "https"} ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = { ALLOWED_NETLOCS = {
@ -81,7 +82,6 @@ class YoutubeLoader:
TranscriptsDisabled, TranscriptsDisabled,
YouTubeTranscriptApi, YouTubeTranscriptApi,
) )
from youtube_transcript_api.proxies import GenericProxyConfig
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Could not import "youtube_transcript_api" Python package. ' 'Could not import "youtube_transcript_api" Python package. '
@ -89,16 +89,19 @@ class YoutubeLoader:
) )
if self.proxy_url: if self.proxy_url:
youtube_proxies = GenericProxyConfig( youtube_proxies = {
http_url=self.proxy_url, https_url=self.proxy_url "http": self.proxy_url,
) "https": self.proxy_url,
}
# Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...") log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else: else:
youtube_proxies = None youtube_proxies = None
transcript_api = YouTubeTranscriptApi(proxy_config=youtube_proxies)
try: try:
transcript_list = transcript_api.list(self.video_id) transcript_list = YouTubeTranscriptApi.list_transcripts(
self.video_id, proxies=youtube_proxies
)
except Exception as e: except Exception as e:
log.exception("Loading YouTube transcript failed") log.exception("Loading YouTube transcript failed")
return [] return []
@ -107,37 +110,11 @@ class YoutubeLoader:
for lang in self.language: for lang in self.language:
try: try:
transcript = transcript_list.find_transcript([lang]) transcript = transcript_list.find_transcript([lang])
if transcript.is_generated:
log.debug(f"Found generated transcript for language '{lang}'")
try:
transcript = transcript_list.find_manually_created_transcript(
[lang]
)
log.debug(f"Found manual transcript for language '{lang}'")
except NoTranscriptFound:
log.debug(
f"No manual transcript found for language '{lang}', using generated"
)
pass
log.debug(f"Found transcript for language '{lang}'") log.debug(f"Found transcript for language '{lang}'")
try: transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
except ParseError:
log.debug(f"Empty or invalid transcript for language '{lang}'")
continue
if not transcript_pieces:
log.debug(f"Empty transcript for language '{lang}'")
continue
transcript_text = " ".join( transcript_text = " ".join(
map( map(
lambda transcript_piece: ( lambda transcript_piece: transcript_piece.text.strip(" "),
transcript_piece.text.strip(" ")
if hasattr(transcript_piece, "text")
else ""
),
transcript_pieces, transcript_pieces,
) )
) )
@ -154,11 +131,6 @@ class YoutubeLoader:
log.warning( log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed." f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
) )
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list)) raise NoTranscriptFound(
f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
async def aload(self) -> Generator[Document, None, None]: )
"""Asynchronously load YouTube transcripts into `Document` objects."""
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.load)
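A hedged sketch of the instance-based youtube_transcript_api usage shown on the newer side of this diff (GenericProxyConfig plus the .list() method); the video id and proxy URL are placeholders.

from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.proxies import GenericProxyConfig

proxy_url = None  # e.g. "http://user:pass@proxy:8080"; kept None here
proxy_config = (
    GenericProxyConfig(http_url=proxy_url, https_url=proxy_url) if proxy_url else None
)

api = YouTubeTranscriptApi(proxy_config=proxy_config)
transcript_list = api.list("dQw4w9WgXcQ")          # placeholder video id
transcript = transcript_list.find_transcript(["en"])
pieces = transcript.fetch()
text = " ".join(piece.text.strip() for piece in pieces if hasattr(piece, "text"))
print(text[:200])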

View file

@ -5,10 +5,12 @@ import numpy as np
from colbert.infra import ColBERTConfig from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT(BaseReranker): class ColBERT(BaseReranker):

View file

@ -1,15 +1,13 @@
import logging import logging
import requests import requests
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from urllib.parse import quote
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.retrieval.models.base_reranker import BaseReranker
from open_webui.utils.headers import include_user_info_headers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalReranker(BaseReranker): class ExternalReranker(BaseReranker):
@ -18,16 +16,12 @@ class ExternalReranker(BaseReranker):
api_key: str, api_key: str,
url: str = "http://localhost:8080/v1/rerank", url: str = "http://localhost:8080/v1/rerank",
model: str = "reranker", model: str = "reranker",
timeout: Optional[int] = None,
): ):
self.api_key = api_key self.api_key = api_key
self.url = url self.url = url
self.model = model self.model = model
self.timeout = timeout
def predict( def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
query = sentences[0][0] query = sentences[0][0]
docs = [i[1] for i in sentences] docs = [i[1] for i in sentences]
@ -42,19 +36,13 @@ class ExternalReranker(BaseReranker):
log.info(f"ExternalReranker:predict:model {self.model}") log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}") log.info(f"ExternalReranker:predict:query {query}")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post( r = requests.post(
f"{self.url}", f"{self.url}",
headers=headers, headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
},
json=payload, json=payload,
timeout=self.timeout,
) )
r.raise_for_status() r.raise_for_status()
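A hedged sketch of posting to an external rerank endpoint with Bearer auth and the new explicit timeout. The payload field names (model, query, documents, top_n) are assumptions, since the payload construction sits outside this hunk.

import requests

url = "http://localhost:8080/v1/rerank"
api_key = "sk-..."  # placeholder
payload = {
    # Field names are assumed for illustration; the real payload is built elsewhere.
    "model": "reranker",
    "query": "what is retrieval augmented generation?",
    "documents": ["RAG combines retrieval with generation.", "Bananas are yellow."],
    "top_n": 2,
}

r = requests.post(
    url,
    headers={
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    },
    json=payload,
    timeout=30,  # the new optional timeout parameter defaults to None (no limit)
)
r.raise_for_status()
print(r.json())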

File diff suppressed because it is too large

View file

@ -11,8 +11,6 @@ from open_webui.retrieval.vector.main import (
SearchResult, SearchResult,
GetResult, GetResult,
) )
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.config import ( from open_webui.config import (
CHROMA_DATA_PATH, CHROMA_DATA_PATH,
CHROMA_HTTP_HOST, CHROMA_HTTP_HOST,
@ -24,8 +22,10 @@ from open_webui.config import (
CHROMA_CLIENT_AUTH_PROVIDER, CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS, CHROMA_CLIENT_AUTH_CREDENTIALS,
) )
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient(VectorDBBase): class ChromaClient(VectorDBBase):
@ -144,7 +144,7 @@ class ChromaClient(VectorDBBase):
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items] embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items] metadatas = [item["metadata"] for item in items]
for batch in create_batches( for batch in create_batches(
api=self.client, api=self.client,
@ -164,7 +164,7 @@ class ChromaClient(VectorDBBase):
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items] embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items] metadatas = [item["metadata"] for item in items]
collection.upsert( collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas

View file

@ -2,8 +2,6 @@ from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional from typing import Optional
import ssl import ssl
from elasticsearch.helpers import bulk, scan from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
@ -245,7 +243,7 @@ class ElasticsearchClient(VectorDBBase):
"collection": collection_name, "collection": collection_name,
"vector": item["vector"], "vector": item["vector"],
"text": item["text"], "text": item["text"],
"metadata": process_metadata(item["metadata"]), "metadata": item["metadata"],
}, },
} }
for item in batch for item in batch
@ -266,7 +264,7 @@ class ElasticsearchClient(VectorDBBase):
"collection": collection_name, "collection": collection_name,
"vector": item["vector"], "vector": item["vector"],
"text": item["text"], "text": item["text"],
"metadata": process_metadata(item["metadata"]), "metadata": item["metadata"],
}, },
"doc_as_upsert": True, "doc_as_upsert": True,
} }

View file

@ -1,12 +1,8 @@
from pymilvus import MilvusClient as Client from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType from pymilvus import FieldSchema, DataType
from pymilvus import connections, Collection
import json import json
import logging import logging
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
@ -22,11 +18,11 @@ from open_webui.config import (
MILVUS_HNSW_M, MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION, MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST, MILVUS_IVF_FLAT_NLIST,
MILVUS_DISKANN_MAX_DEGREE,
MILVUS_DISKANN_SEARCH_LIST_SIZE,
) )
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient(VectorDBBase): class MilvusClient(VectorDBBase):
@ -131,18 +127,12 @@ class MilvusClient(VectorDBBase):
elif index_type == "IVF_FLAT": elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST} index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}") log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type == "DISKANN":
index_creation_params = {
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
}
log.info(f"DISKANN params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]: elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.") log.info(f"Using {index_type} index with no specific build-time params.")
else: else:
log.warning( log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. " f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. " f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation." f"Milvus will use its default for the collection if this type is not directly supported for index creation."
) )
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default. # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
@ -195,54 +185,86 @@ class MilvusClient(VectorDBBase):
) )
return self._result_to_search_result(result) return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: int = -1): def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB) # Construct the filter string for querying
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name): if not self.has_collection(collection_name):
log.warning( log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}" f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
) )
return None return None
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
max_limit = 16383 # The maximum number of records per request
all_results = []
if limit is None:
# Milvus default limit for query if not specified is 16384, but docs mention iteration.
# Let's set a practical high number if "all" is intended, or handle true pagination.
# For now, if limit is None, we'll fetch in batches up to a very large number.
# This part could be refined based on expected use cases for "get all".
# For this function signature, None implies "as many as possible" up to Milvus limits.
limit = (
16384 * 10
) # A large number to signify fetching many, will be capped by actual data or max_limit per call.
log.info(
f"Limit not specified for query, fetching up to {limit} results in batches."
)
filter_expressions = [] # Initialize offset and remaining to handle pagination
for key, value in filter.items(): offset = 0
if isinstance(value, str): remaining = limit
filter_expressions.append(f'metadata["{key}"] == "{value}"')
else:
filter_expressions.append(f'metadata["{key}"] == {value}')
filter_string = " && ".join(filter_expressions)
collection = Collection(f"{self.collection_prefix}_{collection_name}")
collection.load()
try: try:
log.info( log.info(
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}" f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
) )
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
current_fetch = min(
max_limit, remaining if isinstance(remaining, int) else max_limit
)
log.debug(
f"Querying with offset: {offset}, current_fetch: {current_fetch}"
)
iterator = collection.query_iterator( results = self.client.query(
expr=filter_string, collection_name=f"{self.collection_prefix}_{collection_name}",
output_fields=[ filter=filter_string,
"id", output_fields=[
"data", "id",
"metadata", "data",
], "metadata",
limit=limit if limit > 0 else -1, ], # Explicitly list needed fields. Vector not usually needed in query.
) limit=current_fetch,
offset=offset,
)
all_results = [] if not results:
while True: log.debug("No more results from query.")
batch = iterator.next()
if not batch:
iterator.close()
break break
all_results.extend(batch)
log.debug(f"Total results from query: {len(all_results)}") all_results.extend(results)
return self._result_to_get_result([all_results] if all_results else [[]]) results_count = len(results)
log.debug(f"Fetched {results_count} results in this batch.")
if isinstance(remaining, int):
remaining -= results_count
offset += results_count
# Break the loop if the results returned are less than the requested fetch count (means end of data)
if results_count < current_fetch:
log.debug(
"Fetched less than requested, assuming end of results for this query."
)
break
log.info(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results])
except Exception as e: except Exception as e:
log.exception( log.exception(
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}" f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
@ -257,7 +279,7 @@ class MilvusClient(VectorDBBase):
) )
# Using query with a trivial filter to get all items. # Using query with a trivial filter to get all items.
# This will use the paginated query logic. # This will use the paginated query logic.
return self.query(collection_name=collection_name, filter={}, limit=-1) return self.query(collection_name=collection_name, filter={}, limit=None)
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
@ -289,7 +311,7 @@ class MilvusClient(VectorDBBase):
"id": item["id"], "id": item["id"],
"vector": item["vector"], "vector": item["vector"],
"data": {"text": item["text"]}, "data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]), "metadata": item["metadata"],
} }
for item in items for item in items
], ],
@ -325,7 +347,7 @@ class MilvusClient(VectorDBBase):
"id": item["id"], "id": item["id"],
"vector": item["vector"], "vector": item["vector"],
"data": {"text": item["text"]}, "data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]), "metadata": item["metadata"],
} }
for item in items for item in items
], ],

View file

@ -1,287 +0,0 @@
import logging
from typing import Optional, Tuple, List, Dict, Any
from open_webui.config import (
MILVUS_URI,
MILVUS_TOKEN,
MILVUS_DB,
MILVUS_COLLECTION_PREFIX,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from pymilvus import (
connections,
utility,
Collection,
CollectionSchema,
FieldSchema,
DataType,
)
log = logging.getLogger(__name__)
RESOURCE_ID_FIELD = "resource_id"
class MilvusClient(VectorDBBase):
def __init__(self):
# Milvus collection names can only contain numbers, letters, and underscores.
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
connections.connect(
alias="default",
uri=MILVUS_URI,
token=MILVUS_TOKEN,
db_name=MILVUS_DB,
)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
self.shared_collections = [
self.MEMORY_COLLECTION,
self.KNOWLEDGE_COLLECTION,
self.FILE_COLLECTION,
self.WEB_SEARCH_COLLECTION,
self.HASH_BASED_COLLECTION,
]
def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
"""
Maps the traditional collection name to multi-tenant collection and resource ID.
WARNING: This mapping relies on current Open WebUI naming conventions for
collection names. If Open WebUI changes how it generates collection names
(e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
formats), this mapping will break and route data to incorrect collections.
POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
DATA MAPPING INSIDE THE DATABASE.
"""
resource_id = collection_name
if collection_name.startswith("user-memory-"):
return self.MEMORY_COLLECTION, resource_id
elif collection_name.startswith("file-"):
return self.FILE_COLLECTION, resource_id
elif collection_name.startswith("web-search-"):
return self.WEB_SEARCH_COLLECTION, resource_id
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
return self.HASH_BASED_COLLECTION, resource_id
else:
return self.KNOWLEDGE_COLLECTION, resource_id
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
fields = [
FieldSchema(
name="id",
dtype=DataType.VARCHAR,
is_primary=True,
auto_id=False,
max_length=36,
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="metadata", dtype=DataType.JSON),
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
]
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
collection = Collection(mt_collection_name, schema)
index_params = {
"metric_type": MILVUS_METRIC_TYPE,
"index_type": MILVUS_INDEX_TYPE,
"params": {},
}
if MILVUS_INDEX_TYPE == "HNSW":
index_params["params"] = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
}
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
collection.create_index("vector", index_params)
collection.create_index(RESOURCE_ID_FIELD)
log.info(f"Created shared collection: {mt_collection_name}")
return collection
def _ensure_collection(self, mt_collection_name: str, dimension: int):
if not utility.has_collection(mt_collection_name):
self._create_shared_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return False
collection = Collection(mt_collection)
collection.load()
res = collection.query(expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1)
return len(res) > 0
def upsert(self, collection_name: str, items: List[VectorItem]):
if not items:
return
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
dimension = len(items[0]["vector"])
self._ensure_collection(mt_collection, dimension)
collection = Collection(mt_collection)
entities = [
{
"id": item["id"],
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
RESOURCE_ID_FIELD: resource_id,
}
for item in items
]
collection.insert(entities)
def search(
self, collection_name: str, vectors: List[List[float]], limit: int
) -> Optional[SearchResult]:
if not vectors:
return None
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return None
collection = Collection(mt_collection)
collection.load()
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
results = collection.search(
data=vectors,
anns_field="vector",
param=search_params,
limit=limit,
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
output_fields=["id", "text", "metadata"],
)
ids, documents, metadatas, distances = [], [], [], []
for hits in results:
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
for hit in hits:
batch_ids.append(hit.entity.get("id"))
batch_docs.append(hit.entity.get("text"))
batch_metadatas.append(hit.entity.get("metadata"))
batch_dists.append(hit.distance)
ids.append(batch_ids)
documents.append(batch_docs)
metadatas.append(batch_metadatas)
distances.append(batch_dists)
return SearchResult(
ids=ids, documents=documents, metadatas=metadatas, distances=distances
)
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return
collection = Collection(mt_collection)
# Build expression
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if ids:
# Milvus expects a string list for 'in' operator
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
expr.append(f"id in [{id_list_str}]")
if filter:
for key, value in filter.items():
expr.append(f"metadata['{key}'] == '{value}'")
collection.delete(" and ".join(expr))
def reset(self):
for collection_name in self.shared_collections:
if utility.has_collection(collection_name):
utility.drop_collection(collection_name)
def delete_collection(self, collection_name: str):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return
collection = Collection(mt_collection)
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return None
collection = Collection(mt_collection)
collection.load()
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if filter:
for key, value in filter.items():
if isinstance(value, str):
expr.append(f"metadata['{key}'] == '{value}'")
else:
expr.append(f"metadata['{key}'] == {value}")
iterator = collection.query_iterator(
expr=" and ".join(expr),
output_fields=["id", "text", "metadata"],
limit=limit if limit else -1,
)
all_results = []
while True:
batch = iterator.next()
if not batch:
iterator.close()
break
all_results.extend(batch)
ids = [res["id"] for res in all_results]
documents = [res["text"] for res in all_results]
metadatas = [res["metadata"] for res in all_results]
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def get(self, collection_name: str) -> Optional[GetResult]:
return self.query(collection_name, filter={}, limit=None)
def insert(self, collection_name: str, items: List[VectorItem]):
return self.upsert(collection_name, items)

View file

@ -2,7 +2,6 @@ from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk from opensearchpy.helpers import bulk
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
@ -158,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items(): for field, value in filter.items():
query_body["query"]["bool"]["filter"].append( query_body["query"]["bool"]["filter"].append(
{"term": {"metadata." + str(field) + ".keyword": value}} {"match": {"metadata." + str(field): value}}
) )
size = limit if limit else 10000 size = limit if limit else 10
try: try:
result = self.client.search( result = self.client.search(
@ -201,13 +200,12 @@ class OpenSearchClient(VectorDBBase):
"_source": { "_source": {
"vector": item["vector"], "vector": item["vector"],
"text": item["text"], "text": item["text"],
"metadata": process_metadata(item["metadata"]), "metadata": item["metadata"],
}, },
} }
for item in batch for item in batch
] ]
bulk(self.client, actions) bulk(self.client, actions)
self.client.indices.refresh(self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists( self._create_index_if_not_exists(
@ -223,14 +221,13 @@ class OpenSearchClient(VectorDBBase):
"doc": { "doc": {
"vector": item["vector"], "vector": item["vector"],
"text": item["text"], "text": item["text"],
"metadata": process_metadata(item["metadata"]), "metadata": item["metadata"],
}, },
"doc_as_upsert": True, "doc_as_upsert": True,
} }
for item in batch for item in batch
] ]
bulk(self.client, actions) bulk(self.client, actions)
self.client.indices.refresh(self._get_index_name(collection_name))
def delete( def delete(
self, self,
@ -254,12 +251,11 @@ class OpenSearchClient(VectorDBBase):
} }
for field, value in filter.items(): for field, value in filter.items():
query_body["query"]["bool"]["filter"].append( query_body["query"]["bool"]["filter"].append(
{"term": {"metadata." + str(field) + ".keyword": value}} {"match": {"metadata." + str(field): value}}
) )
self.client.delete_by_query( self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body index=self._get_index_name(collection_name), body=query_body
) )
self.client.indices.refresh(self._get_index_name(collection_name))
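A hedged sketch of the newer filter style in this file: exact matches use a term query against the .keyword subfield instead of match, and writes are followed by an explicit index refresh. Host, credentials, index name, and field name are placeholders.

from opensearchpy import OpenSearch

client = OpenSearch(
    hosts=["https://localhost:9200"],     # placeholder host and credentials
    http_auth=("admin", "admin"),
    verify_certs=False,
)
index = "open_webui_my_collection"

query_body = {
    "query": {
        "bool": {
            "filter": [
                {"term": {"metadata.file_id.keyword": "file-123"}}  # exact, not analyzed
            ]
        }
    },
    "size": 10000,
}
result = client.search(index=index, body=query_body)
print(result["hits"]["total"])

client.indices.refresh(index=index)  # make preceding bulk writes visible to search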
def reset(self): def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*") indices = self.client.indices.get(index=f"{self.index_prefix}_*")

View file

@ -1,941 +0,0 @@
"""
Oracle 23ai Vector Database Client - Fixed Version
# .env
VECTOR_DB = "oracle23ai"
## DBCS or oracle 23ai free
ORACLE_DB_USE_WALLET = false
ORACLE_DB_USER = "DEMOUSER"
ORACLE_DB_PASSWORD = "Welcome123456"
ORACLE_DB_DSN = "localhost:1521/FREEPDB1"
## ADW or ATP
# ORACLE_DB_USE_WALLET = true
# ORACLE_DB_USER = "DEMOUSER"
# ORACLE_DB_PASSWORD = "Welcome123456"
# ORACLE_DB_DSN = "medium"
# ORACLE_DB_DSN = "(description= (retry_count=3)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=xx.oraclecloud.com))(connect_data=(service_name=yy.adb.oraclecloud.com))(security=(ssl_server_dn_match=no)))"
# ORACLE_WALLET_DIR = "/home/opc/adb_wallet"
# ORACLE_WALLET_PASSWORD = "Welcome1"
ORACLE_VECTOR_LENGTH = 768
ORACLE_DB_POOL_MIN = 2
ORACLE_DB_POOL_MAX = 10
ORACLE_DB_POOL_INCREMENT = 1
"""
from typing import Optional, List, Dict, Any, Union
from decimal import Decimal
import logging
import os
import threading
import time
import json
import array
import oracledb
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ORACLE_DB_USE_WALLET,
ORACLE_DB_USER,
ORACLE_DB_PASSWORD,
ORACLE_DB_DSN,
ORACLE_WALLET_DIR,
ORACLE_WALLET_PASSWORD,
ORACLE_VECTOR_LENGTH,
ORACLE_DB_POOL_MIN,
ORACLE_DB_POOL_MAX,
ORACLE_DB_POOL_INCREMENT,
)
log = logging.getLogger(__name__)
class Oracle23aiClient(VectorDBBase):
"""
Oracle Vector Database Client for vector similarity search using Oracle Database 23ai.
This client provides an interface to store, retrieve, and search vector embeddings
in an Oracle database. It uses connection pooling for efficient database access
and supports vector similarity search operations.
Attributes:
pool: Connection pool for Oracle database connections
"""
def __init__(self) -> None:
"""
Initialize the Oracle23aiClient with a connection pool.
Creates a connection pool with configurable min/max connections, initializes
the database schema if needed, and sets up necessary tables and indexes.
Raises:
ValueError: If required configuration parameters are missing
Exception: If database initialization fails
"""
self.pool = None
try:
# Create the appropriate connection pool based on DB type
if ORACLE_DB_USE_WALLET:
self._create_adb_pool()
else: # DBCS
self._create_dbcs_pool()
dsn = ORACLE_DB_DSN
log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]")
with self.get_connection() as connection:
log.info(f"Connection version: {connection.version}")
self._initialize_database(connection)
log.info("Oracle Vector Search initialization complete.")
except Exception as e:
log.exception(f"Error during Oracle Vector Search initialization: {e}")
raise
def _create_adb_pool(self) -> None:
"""
Create connection pool for Oracle Autonomous Database.
Uses wallet-based authentication.
"""
self.pool = oracledb.create_pool(
user=ORACLE_DB_USER,
password=ORACLE_DB_PASSWORD,
dsn=ORACLE_DB_DSN,
min=ORACLE_DB_POOL_MIN,
max=ORACLE_DB_POOL_MAX,
increment=ORACLE_DB_POOL_INCREMENT,
config_dir=ORACLE_WALLET_DIR,
wallet_location=ORACLE_WALLET_DIR,
wallet_password=ORACLE_WALLET_PASSWORD,
)
log.info("Created ADB connection pool with wallet authentication.")
def _create_dbcs_pool(self) -> None:
"""
Create connection pool for Oracle Database Cloud Service.
Uses basic authentication without wallet.
"""
self.pool = oracledb.create_pool(
user=ORACLE_DB_USER,
password=ORACLE_DB_PASSWORD,
dsn=ORACLE_DB_DSN,
min=ORACLE_DB_POOL_MIN,
max=ORACLE_DB_POOL_MAX,
increment=ORACLE_DB_POOL_INCREMENT,
)
log.info("Created DB connection pool with basic authentication.")
def get_connection(self):
"""
Acquire a connection from the connection pool with retry logic.
Returns:
connection: A database connection with output type handler configured
"""
max_retries = 3
for attempt in range(max_retries):
try:
connection = self.pool.acquire()
connection.outputtypehandler = self._output_type_handler
return connection
except oracledb.DatabaseError as e:
(error_obj,) = e.args
log.exception(
f"Connection attempt {attempt + 1} failed: {error_obj.message}"
)
if attempt < max_retries - 1:
wait_time = 2**attempt
log.info(f"Retrying in {wait_time} seconds...")
time.sleep(wait_time)
else:
raise
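A minimal python-oracledb pooling sketch along the lines of the client above: create a small pool, acquire connections through a context manager so they return to the pool automatically, and run a liveness query. Credentials and DSN are the placeholder values from the file header, not real secrets.

import oracledb

pool = oracledb.create_pool(
    user="DEMOUSER",
    password="Welcome123456",
    dsn="localhost:1521/FREEPDB1",
    min=2,
    max=10,
    increment=1,
)

with pool.acquire() as connection:        # released back to the pool on exit
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1 FROM dual")
        print(cursor.fetchone())          # (1,)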
def start_health_monitor(self, interval_seconds: int = 60):
"""
Start a background thread to periodically check the health of the connection pool.
Args:
interval_seconds (int): Number of seconds between health checks
"""
def _monitor():
while True:
try:
log.info("[HealthCheck] Running periodic DB health check...")
self.ensure_connection()
log.info("[HealthCheck] Connection is healthy.")
except Exception as e:
log.exception(f"[HealthCheck] Connection health check failed: {e}")
time.sleep(interval_seconds)
thread = threading.Thread(target=_monitor, daemon=True)
thread.start()
log.info(f"Started DB health monitor every {interval_seconds} seconds.")
def _reconnect_pool(self):
"""
Attempt to reinitialize the connection pool if it's been closed or broken.
"""
try:
log.info("Attempting to reinitialize the Oracle connection pool...")
# Close existing pool if it exists
if self.pool:
try:
self.pool.close()
except Exception as close_error:
log.warning(f"Error closing existing pool: {close_error}")
# Re-create the appropriate connection pool based on DB type
if ORACLE_DB_USE_WALLET:
self._create_adb_pool()
else: # DBCS
self._create_dbcs_pool()
log.info("Connection pool reinitialized.")
except Exception as e:
log.exception(f"Failed to reinitialize the connection pool: {e}")
raise
def ensure_connection(self):
"""
Ensure the database connection is alive, reconnecting pool if needed.
"""
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM dual")
except Exception as e:
log.exception(
f"Connection check failed: {e}, attempting to reconnect pool..."
)
self._reconnect_pool()
def _output_type_handler(self, cursor, metadata):
"""
Handle Oracle vector type conversion.
Args:
cursor: Oracle database cursor
metadata: Metadata for the column
Returns:
A variable with appropriate conversion for vector types
"""
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
return cursor.var(
metadata.type_code, arraysize=cursor.arraysize, outconverter=list
)
def _initialize_database(self, connection) -> None:
"""
Initialize database schema, tables and indexes.
Creates the document_chunk table and necessary indexes if they don't exist.
Args:
connection: Oracle database connection
Raises:
Exception: If schema initialization fails
"""
with connection.cursor() as cursor:
try:
log.info("Creating Table document_chunk")
cursor.execute(
"""
BEGIN
EXECUTE IMMEDIATE '
CREATE TABLE IF NOT EXISTS document_chunk (
id VARCHAR2(255) PRIMARY KEY,
collection_name VARCHAR2(255) NOT NULL,
text CLOB,
vmetadata JSON,
vector vector(*, float32)
)
';
EXCEPTION
WHEN OTHERS THEN
IF SQLCODE != -955 THEN
RAISE;
END IF;
END;
"""
)
log.info("Creating Index document_chunk_collection_name_idx")
cursor.execute(
"""
BEGIN
EXECUTE IMMEDIATE '
CREATE INDEX IF NOT EXISTS document_chunk_collection_name_idx
ON document_chunk (collection_name)
';
EXCEPTION
WHEN OTHERS THEN
IF SQLCODE != -955 THEN
RAISE;
END IF;
END;
"""
)
log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
cursor.execute(
"""
BEGIN
EXECUTE IMMEDIATE '
CREATE VECTOR INDEX IF NOT EXISTS document_chunk_vector_ivf_idx
ON document_chunk(vector)
ORGANIZATION NEIGHBOR PARTITIONS
DISTANCE COSINE
WITH TARGET ACCURACY 95
PARAMETERS (TYPE IVF, NEIGHBOR PARTITIONS 100)
';
EXCEPTION
WHEN OTHERS THEN
IF SQLCODE != -955 THEN
RAISE;
END IF;
END;
"""
)
connection.commit()
log.info("Database initialization completed successfully.")
except Exception as e:
connection.rollback()
log.exception(f"Error during database initialization: {e}")
raise
def check_vector_length(self) -> None:
"""
Check vector length compatibility (placeholder).
This method would check if the configured vector length matches the database schema.
Currently implemented as a placeholder.
"""
pass
def _vector_to_blob(self, vector: List[float]) -> bytes:
"""
Convert a vector to Oracle BLOB format.
Args:
vector (List[float]): The vector to convert
Returns:
bytes: The vector in Oracle BLOB format
"""
return array.array("f", vector)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
"""
Adjust vector to the expected length if needed.
Args:
vector (List[float]): The vector to adjust
Returns:
List[float]: The adjusted vector
"""
return vector
def _decimal_handler(self, obj):
"""
Handle Decimal objects for JSON serialization.
Args:
obj: Object to serialize
Returns:
float: Converted decimal value
Raises:
TypeError: If object is not JSON serializable
"""
if isinstance(obj, Decimal):
return float(obj)
raise TypeError(f"{obj} is not JSON serializable")
def _metadata_to_json(self, metadata: Dict) -> str:
"""
Convert metadata dictionary to JSON string.
Args:
metadata (Dict): Metadata dictionary
Returns:
str: JSON representation of metadata
"""
return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"
def _json_to_metadata(self, json_str: str) -> Dict:
"""
Convert JSON string to metadata dictionary.
Args:
json_str (str): JSON string
Returns:
Dict: Metadata dictionary
"""
return json.loads(json_str) if json_str else {}
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert vector items into the database.
Args:
collection_name (str): Name of the collection
items (List[VectorItem]): List of vector items to insert
Raises:
Exception: If insertion fails
Example:
>>> client = Oracle23aiClient()
>>> items = [
... {"id": "1", "text": "Sample text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
... {"id": "2", "text": "Another text", "vector": [0.3, 0.4, ...], "metadata": {"source": "doc2"}}
... ]
>>> client.insert("my_collection", items)
"""
log.info(f"Inserting {len(items)} items into collection '{collection_name}'.")
with self.get_connection() as connection:
try:
with connection.cursor() as cursor:
for item in items:
vector_blob = self._vector_to_blob(item["vector"])
metadata_json = self._metadata_to_json(item["metadata"])
cursor.execute(
"""
INSERT INTO document_chunk
(id, collection_name, text, vmetadata, vector)
VALUES (:id, :collection_name, :text, :metadata, :vector)
""",
{
"id": item["id"],
"collection_name": collection_name,
"text": item["text"],
"metadata": metadata_json,
"vector": vector_blob,
},
)
connection.commit()
log.info(
f"Successfully inserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
connection.rollback()
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Update or insert vector items into the database.
If an item with the same ID exists, it will be updated;
otherwise, it will be inserted.
Args:
collection_name (str): Name of the collection
items (List[VectorItem]): List of vector items to upsert
Raises:
Exception: If upsert operation fails
Example:
>>> client = Oracle23aiClient()
>>> items = [
... {"id": "1", "text": "Updated text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
... {"id": "3", "text": "New item", "vector": [0.5, 0.6, ...], "metadata": {"source": "doc3"}}
... ]
>>> client.upsert("my_collection", items)
"""
log.info(f"Upserting {len(items)} items into collection '{collection_name}'.")
with self.get_connection() as connection:
try:
with connection.cursor() as cursor:
for item in items:
vector_blob = self._vector_to_blob(item["vector"])
metadata_json = self._metadata_to_json(item["metadata"])
cursor.execute(
"""
MERGE INTO document_chunk d
USING (SELECT :merge_id as id FROM dual) s
ON (d.id = s.id)
WHEN MATCHED THEN
UPDATE SET
collection_name = :upd_collection_name,
text = :upd_text,
vmetadata = :upd_metadata,
vector = :upd_vector
WHEN NOT MATCHED THEN
INSERT (id, collection_name, text, vmetadata, vector)
VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
""",
{
"merge_id": item["id"],
"upd_collection_name": collection_name,
"upd_text": item["text"],
"upd_metadata": metadata_json,
"upd_vector": vector_blob,
"ins_id": item["id"],
"ins_collection_name": collection_name,
"ins_text": item["text"],
"ins_metadata": metadata_json,
"ins_vector": vector_blob,
},
)
connection.commit()
log.info(
f"Successfully upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
connection.rollback()
log.exception(f"Error during upsert: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""
Search for similar vectors in the database.
Performs vector similarity search using cosine distance.
Args:
collection_name (str): Name of the collection to search
vectors (List[List[Union[float, int]]]): Query vectors to find similar items for
limit (int): Maximum number of results to return per query
Returns:
Optional[SearchResult]: Search results containing ids, distances, documents, and metadata
Example:
>>> client = Oracle23aiClient()
>>> query_vector = [0.1, 0.2, 0.3, ...] # Must match VECTOR_LENGTH
>>> results = client.search("my_collection", [query_vector], limit=5)
>>> if results:
... log.info(f"Found {len(results.ids[0])} matches")
... for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
... log.info(f"Match {i+1}: id={id}, distance={dist}")
"""
log.info(
f"Searching items from collection '{collection_name}' with limit {limit}."
)
try:
if not vectors:
log.warning("No vectors provided for search.")
return None
num_queries = len(vectors)
ids = [[] for _ in range(num_queries)]
distances = [[] for _ in range(num_queries)]
documents = [[] for _ in range(num_queries)]
metadatas = [[] for _ in range(num_queries)]
with self.get_connection() as connection:
with connection.cursor() as cursor:
for qid, vector in enumerate(vectors):
vector_blob = self._vector_to_blob(vector)
cursor.execute(
"""
SELECT dc.id, dc.text,
JSON_SERIALIZE(dc.vmetadata RETURNING VARCHAR2(4096)) as vmetadata,
VECTOR_DISTANCE(dc.vector, :query_vector, COSINE) as distance
FROM document_chunk dc
WHERE dc.collection_name = :collection_name
ORDER BY VECTOR_DISTANCE(dc.vector, :query_vector, COSINE)
FETCH APPROX FIRST :limit ROWS ONLY
""",
{
"query_vector": vector_blob,
"collection_name": collection_name,
"limit": limit,
},
)
results = cursor.fetchall()
for row in results:
ids[qid].append(row[0])
documents[qid].append(
row[1].read()
if isinstance(row[1], oracledb.LOB)
else str(row[1])
)
# Parse the serialized JSON metadata back into a dict
metadata_str = (
row[2].read()
if isinstance(row[2], oracledb.LOB)
else row[2]
)
metadatas[qid].append(self._json_to_metadata(metadata_str))
distances[qid].append(float(row[3]))
log.info(
f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
)
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
log.exception(f"Error during search: {e}")
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""
Query items based on metadata filters.
Retrieves items that match specified metadata criteria.
Args:
collection_name (str): Name of the collection to query
filter (Dict[str, Any]): Metadata filters to apply
limit (Optional[int]): Maximum number of results to return
Returns:
Optional[GetResult]: Query results containing ids, documents, and metadata
Example:
>>> client = Oracle23aiClient()
>>> filter = {"source": "doc1", "category": "finance"}
>>> results = client.query("my_collection", filter, limit=20)
>>> if results:
... print(f"Found {len(results.ids[0])} matching documents")
"""
log.info(f"Querying items from collection '{collection_name}' with filters.")
try:
limit = limit or 100
query = """
SELECT id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
FROM document_chunk
WHERE collection_name = :collection_name
"""
params = {"collection_name": collection_name}
for i, (key, value) in enumerate(filter.items()):
param_name = f"value_{i}"
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
params[param_name] = str(value)
query += " FETCH FIRST :limit ROWS ONLY"
params["limit"] = limit
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(query, params)
results = cursor.fetchall()
if not results:
log.info("No results found for query.")
return None
ids = [[row[0] for row in results]]
documents = [
[
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
for row in results
]
]
# Parse the serialized JSON metadata back into a dict
metadatas = [
[
self._json_to_metadata(
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
)
for row in results
]
]
log.info(f"Query completed. Found {len(results)} results.")
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during query: {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Get all items in a collection.
Retrieves items from a specified collection, up to an internal limit of 1000 rows.
Args:
collection_name (str): Name of the collection to retrieve
Returns:
Optional[GetResult]: Result containing ids, documents, and metadata
Example:
>>> client = Oracle23aiClient()
>>> results = client.get("my_collection")
>>> if results:
... print(f"Retrieved {len(results.ids[0])} documents from collection")
"""
limit = 1000  # Hardcoded limit for get operation
log.info(
f"Getting items from collection '{collection_name}' with limit {limit}."
)
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT /*+ MONITOR */ id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
FROM document_chunk
WHERE collection_name = :collection_name
FETCH FIRST :limit ROWS ONLY
""",
{"collection_name": collection_name, "limit": limit},
)
results = cursor.fetchall()
if not results:
log.info("No results found.")
return None
ids = [[row[0] for row in results]]
documents = [
[
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
for row in results
]
]
# Parse the serialized JSON metadata back into a dict
metadatas = [
[
self._json_to_metadata(
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
)
for row in results
]
]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during get: {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> None:
"""
Delete items from the database.
Deletes items from a collection based on IDs or metadata filters.
Args:
collection_name (str): Name of the collection to delete from
ids (Optional[List[str]]): Specific item IDs to delete
filter (Optional[Dict[str, Any]]): Metadata filters for deletion
Raises:
Exception: If deletion fails
Example:
>>> client = Oracle23aiClient()
>>> # Delete specific items by ID
>>> client.delete("my_collection", ids=["1", "3", "5"])
>>> # Or delete by metadata filter
>>> client.delete("my_collection", filter={"source": "deprecated_source"})
"""
log.info(f"Deleting items from collection '{collection_name}'.")
try:
query = (
"DELETE FROM document_chunk WHERE collection_name = :collection_name"
)
params = {"collection_name": collection_name}
if ids:
# Bind each ID as a parameter to prevent SQL injection
placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
query += f" AND id IN ({placeholders})"
for i, id_val in enumerate(ids):
params[f"id_{i}"] = id_val
if filter:
for i, (key, value) in enumerate(filter.items()):
param_name = f"value_{i}"
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
params[param_name] = str(value)
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(query, params)
deleted = cursor.rowcount
connection.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
log.exception(f"Error during delete: {e}")
raise
def reset(self) -> None:
"""
Reset the database by deleting all items.
Deletes all items from the document_chunk table.
Raises:
Exception: If reset fails
Example:
>>> client = Oracle23aiClient()
>>> client.reset() # Warning: Removes all data!
"""
log.info("Resetting database - deleting all items.")
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("DELETE FROM document_chunk")
deleted = cursor.rowcount
connection.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
log.exception(f"Error during reset: {e}")
raise
def close(self) -> None:
"""
Close the database connection pool.
Properly closes the connection pool and releases all resources.
Example:
>>> client = Oracle23aiClient()
>>> # After finishing all operations
>>> client.close()
"""
try:
if hasattr(self, "pool") and self.pool:
self.pool.close()
log.info("Oracle Vector Search connection pool closed.")
except Exception as e:
log.exception(f"Error closing connection pool: {e}")
def has_collection(self, collection_name: str) -> bool:
"""
Check if a collection exists.
Args:
collection_name (str): Name of the collection to check
Returns:
bool: True if the collection exists, False otherwise
Example:
>>> client = Oracle23aiClient()
>>> if client.has_collection("my_collection"):
... print("Collection exists!")
... else:
... print("Collection does not exist.")
"""
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT COUNT(*)
FROM document_chunk
WHERE collection_name = :collection_name
FETCH FIRST 1 ROWS ONLY
""",
{"collection_name": collection_name},
)
count = cursor.fetchone()[0]
return count > 0
except Exception as e:
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
"""
Delete an entire collection.
Removes all items belonging to the specified collection.
Args:
collection_name (str): Name of the collection to delete
Example:
>>> client = Oracle23aiClient()
>>> client.delete_collection("obsolete_collection")
"""
log.info(f"Deleting collection '{collection_name}'.")
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(
"""
DELETE FROM document_chunk
WHERE collection_name = :collection_name
""",
{"collection_name": collection_name},
)
deleted = cursor.rowcount
connection.commit()
log.info(
f"Collection '{collection_name}' deleted. Removed {deleted} items."
)
except Exception as e:
log.exception(f"Error deleting collection '{collection_name}': {e}")
raise
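Taken together, the client exposes a small insert/upsert/search/delete surface. A minimal usage sketch (illustrative only; it assumes the Oracle connection settings consumed by Oracle23aiClient are configured elsewhere and that a 768-dimensional embedding model is in use):

# Hypothetical end-to-end walkthrough; the collection name and vectors are made up.
client = Oracle23aiClient()
client.upsert(
    "demo_collection",
    [
        {
            "id": "chunk-1",
            "text": "Hello world",
            "vector": [0.1] * 768,  # must match the configured VECTOR dimension
            "metadata": {"source": "demo"},
        }
    ],
)
results = client.search("demo_collection", [[0.1] * 768], limit=3)
if results:
    for id_, dist in zip(results.ids[0], results.distances[0]):
        print(f"{id_}: cosine distance {dist:.4f}")
client.delete("demo_collection", filter={"source": "demo"})
client.close()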

View file

@ -1,16 +1,12 @@
from typing import Optional, List, Dict, Any, Tuple from typing import Optional, List, Dict, Any
import logging import logging
import json
from sqlalchemy import ( from sqlalchemy import (
func,
literal,
cast, cast,
column, column,
create_engine, create_engine,
Column, Column,
Integer, Integer,
MetaData, MetaData,
LargeBinary,
select, select,
text, text,
Text, Text,
@ -18,71 +14,39 @@ from sqlalchemy import (
values, values,
) )
from sqlalchemy.sql import true from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector, HALFVEC from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
SearchResult, SearchResult,
GetResult, GetResult,
) )
from open_webui.config import ( from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
PGVECTOR_DB_URL,
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_CREATE_EXTENSION,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
PGVECTOR_POOL_SIZE,
PGVECTOR_POOL_MAX_OVERFLOW,
PGVECTOR_POOL_TIMEOUT,
PGVECTOR_POOL_RECYCLE,
PGVECTOR_INDEX_METHOD,
PGVECTOR_HNSW_M,
PGVECTOR_HNSW_EF_CONSTRUCTION,
PGVECTOR_IVFFLAT_LISTS,
PGVECTOR_USE_HALFVEC,
)
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
Base = declarative_base() Base = declarative_base()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def pgcrypto_encrypt(val, key):
return func.pgp_sym_encrypt(val, literal(key))
def pgcrypto_decrypt(col, key, outtype="text"):
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base): class DocumentChunk(Base):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True)
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True) vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False) collection_name = Column(Text, nullable=False)
text = Column(Text, nullable=True)
if PGVECTOR_PGCRYPTO: vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
text = Column(LargeBinary, nullable=True)
vmetadata = Column(LargeBinary, nullable=True)
else:
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase): class PgvectorClient(VectorDBBase):
@ -94,24 +58,9 @@ class PgvectorClient(VectorDBBase):
self.session = Session self.session = Session
else: else:
if isinstance(PGVECTOR_POOL_SIZE, int): engine = create_engine(
if PGVECTOR_POOL_SIZE > 0: PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
engine = create_engine( )
PGVECTOR_DB_URL,
pool_size=PGVECTOR_POOL_SIZE,
max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
pool_timeout=PGVECTOR_POOL_TIMEOUT,
pool_recycle=PGVECTOR_POOL_RECYCLE,
pool_pre_ping=True,
poolclass=QueuePool,
)
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
SessionLocal = sessionmaker( SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
) )
@ -119,41 +68,7 @@ class PgvectorClient(VectorDBBase):
try: try:
# Ensure the pgvector extension is available # Ensure the pgvector extension is available
# Use a conditional check to avoid permission issues on Azure PostgreSQL self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
if PGVECTOR_CREATE_EXTENSION:
self.session.execute(
text(
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
CREATE EXTENSION IF NOT EXISTS vector;
END IF;
END $$;
"""
)
)
if PGVECTOR_PGCRYPTO:
# Ensure the pgcrypto extension is available for encryption
# Use a conditional check to avoid permission issues on Azure PostgreSQL
self.session.execute(
text(
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
CREATE EXTENSION IF NOT EXISTS pgcrypto;
END IF;
END $$;
"""
)
)
if not PGVECTOR_PGCRYPTO_KEY:
raise ValueError(
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
)
# Check vector length consistency # Check vector length consistency
self.check_vector_length() self.check_vector_length()
@ -164,9 +79,13 @@ class PgvectorClient(VectorDBBase):
connection = self.session.connection() connection = self.session.connection()
Base.metadata.create_all(bind=connection) Base.metadata.create_all(bind=connection)
index_method, index_options = self._vector_index_configuration() # Create an index on the vector column if it doesn't exist
self._ensure_vector_index(index_method, index_options) self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
)
)
self.session.execute( self.session.execute(
text( text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
@ -180,78 +99,6 @@ class PgvectorClient(VectorDBBase):
log.exception(f"Error during initialization: {e}") log.exception(f"Error during initialization: {e}")
raise raise
@staticmethod
def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
if not index_def:
return None
try:
after_using = index_def.lower().split("using ", 1)[1]
return after_using.split()[0]
except (IndexError, AttributeError):
return None
def _vector_index_configuration(self) -> Tuple[str, str]:
if PGVECTOR_INDEX_METHOD:
index_method = PGVECTOR_INDEX_METHOD
log.info(
"Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
index_method,
)
elif USE_HALFVEC:
index_method = "hnsw"
log.info(
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
VECTOR_LENGTH,
)
else:
index_method = "ivfflat"
if index_method == "hnsw":
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
else:
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
return index_method, index_options
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
index_name = "idx_document_chunk_vector"
existing_index_def = self.session.execute(
text(
"""
SELECT indexdef
FROM pg_indexes
WHERE schemaname = current_schema()
AND tablename = 'document_chunk'
AND indexname = :index_name
"""
),
{"index_name": index_name},
).scalar()
existing_method = self._extract_index_method(existing_index_def)
if existing_method and existing_method != index_method:
raise RuntimeError(
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
"and recreate it with the new method before restarting Open WebUI."
)
if not existing_index_def:
index_sql = (
f"CREATE INDEX IF NOT EXISTS {index_name} "
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
)
if index_options:
index_sql = f"{index_sql} {index_options}"
self.session.execute(text(index_sql))
log.info(
"Ensured vector index '%s' using %s%s.",
index_name,
index_method,
f" {index_options}" if index_options else "",
)
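# Illustrative examples (not part of the original source) of the DDL this helper emits,
# assuming default tuning values of m=16, ef_construction=64 and lists=100:
#   hnsw:    CREATE INDEX IF NOT EXISTS idx_document_chunk_vector
#            ON document_chunk USING hnsw (vector halfvec_cosine_ops) WITH (m = 16, ef_construction = 64)
#   ivfflat: CREATE INDEX IF NOT EXISTS idx_document_chunk_vector
#            ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100)
# The operator class follows PGVECTOR_USE_HALFVEC (halfvec_cosine_ops vs. vector_cosine_ops).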
def check_vector_length(self) -> None: def check_vector_length(self) -> None:
""" """
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database. Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
@ -271,19 +118,16 @@ class PgvectorClient(VectorDBBase):
if "vector" in document_chunk_table.columns: if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"] vector_column = document_chunk_table.columns["vector"]
vector_type = vector_column.type vector_type = vector_column.type
expected_type = HALFVEC if USE_HALFVEC else Vector if isinstance(vector_type, Vector):
db_vector_length = vector_type.dim
if not isinstance(vector_type, expected_type): if db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
)
else:
raise Exception( raise Exception(
"The 'vector' column type does not match the expected type " "The 'vector' column exists but is not of type 'Vector'."
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
)
db_vector_length = getattr(vector_type, "dim", None)
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
) )
else: else:
raise Exception( raise Exception(
@ -303,54 +147,22 @@ class PgvectorClient(VectorDBBase):
def insert(self, collection_name: str, items: List[VectorItem]) -> None: def insert(self, collection_name: str, items: List[VectorItem]) -> None:
try: try:
if PGVECTOR_PGCRYPTO: new_items = []
for item in items: for item in items:
vector = self.adjust_vector_length(item["vector"]) vector = self.adjust_vector_length(item["vector"])
# Use raw SQL for BYTEA/pgcrypto new_chunk = DocumentChunk(
# Ensure metadata is converted to its JSON text representation id=item["id"],
json_metadata = json.dumps(item["metadata"]) vector=vector,
self.session.execute( collection_name=collection_name,
text( text=item["text"],
""" vmetadata=item["metadata"],
INSERT INTO document_chunk
(id, vector, collection_name, text, vmetadata)
VALUES (
:id, :vector, :collection_name,
pgp_sym_encrypt(:text, :key),
pgp_sym_encrypt(:metadata_text, :key)
)
ON CONFLICT (id) DO NOTHING
"""
),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
else:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
) )
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
log.exception(f"Error during insert: {e}") log.exception(f"Error during insert: {e}")
@ -358,66 +170,33 @@ class PgvectorClient(VectorDBBase):
def upsert(self, collection_name: str, items: List[VectorItem]) -> None: def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try: try:
if PGVECTOR_PGCRYPTO: for item in items:
for item in items: vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item["vector"]) existing = (
json_metadata = json.dumps(item["metadata"]) self.session.query(DocumentChunk)
self.session.execute( .filter(DocumentChunk.id == item["id"])
text( .first()
"""
INSERT INTO document_chunk
(id, vector, collection_name, text, vmetadata)
VALUES (
:id, :vector, :collection_name,
pgp_sym_encrypt(:text, :key),
pgp_sym_encrypt(:metadata_text, :key)
)
ON CONFLICT (id) DO UPDATE SET
vector = EXCLUDED.vector,
collection_name = EXCLUDED.collection_name,
text = EXCLUDED.text,
vmetadata = EXCLUDED.vmetadata
"""
),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
else:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = process_metadata(item["metadata"])
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
) )
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
log.exception(f"Error during upsert: {e}") log.exception(f"Error during upsert: {e}")
@ -438,11 +217,11 @@ class PgvectorClient(VectorDBBase):
num_queries = len(vectors) num_queries = len(vectors)
def vector_expr(vector): def vector_expr(vector):
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) return cast(array(vector), Vector(VECTOR_LENGTH))
# Create the values for query vectors # Create the values for query vectors
qid_col = column("qid", Integer) qid_col = column("qid", Integer)
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = ( query_vectors = (
values(qid_col, q_vector_col) values(qid_col, q_vector_col)
.data( .data(
@ -451,32 +230,16 @@ class PgvectorClient(VectorDBBase):
.alias("query_vectors") .alias("query_vectors")
) )
result_fields = [
DocumentChunk.id,
]
if PGVECTOR_PGCRYPTO:
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text")
)
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata")
)
else:
result_fields.append(DocumentChunk.text)
result_fields.append(DocumentChunk.vmetadata)
result_fields.append(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
)
)
# Build the lateral subquery for each query vector # Build the lateral subquery for each query vector
subq = ( subq = (
select(*result_fields) select(
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
).label("distance"),
)
.where(DocumentChunk.collection_name == collection_name) .where(DocumentChunk.collection_name == collection_name)
.order_by( .order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
@ -525,12 +288,10 @@ class PgvectorClient(VectorDBBase):
documents[qid].append(row.text) documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata) metadatas[qid].append(row.vmetadata)
self.session.rollback() # read-only transaction
return SearchResult( return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas ids=ids, distances=distances, documents=documents, metadatas=metadatas
) )
except Exception as e: except Exception as e:
self.session.rollback()
log.exception(f"Error during search: {e}") log.exception(f"Error during search: {e}")
return None return None
@ -538,43 +299,17 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]: ) -> Optional[GetResult]:
try: try:
if PGVECTOR_PGCRYPTO: query = self.session.query(DocumentChunk).filter(
# Build where clause for vmetadata filter DocumentChunk.collection_name == collection_name
where_clauses = [DocumentChunk.collection_name == collection_name] )
for key, value in filter.items():
# decrypt then check key: JSON filter after decryption
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
== str(value)
)
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
).where(*where_clauses)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
for key, value in filter.items(): for key, value in filter.items():
query = query.filter( query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
DocumentChunk.vmetadata[key].astext == str(value)
)
if limit is not None: if limit is not None:
query = query.limit(limit) query = query.limit(limit)
results = query.all() results = query.all()
if not results: if not results:
return None return None
@ -583,14 +318,12 @@ class PgvectorClient(VectorDBBase):
documents = [[result.text for result in results]] documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]] metadatas = [[result.vmetadata for result in results]]
self.session.rollback() # read-only transaction
return GetResult( return GetResult(
ids=ids, ids=ids,
documents=documents, documents=documents,
metadatas=metadatas, metadatas=metadatas,
) )
except Exception as e: except Exception as e:
self.session.rollback()
log.exception(f"Error during query: {e}") log.exception(f"Error during query: {e}")
return None return None
@ -598,43 +331,23 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, limit: Optional[int] = None self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]: ) -> Optional[GetResult]:
try: try:
if PGVECTOR_PGCRYPTO: query = self.session.query(DocumentChunk).filter(
stmt = select( DocumentChunk.collection_name == collection_name
DocumentChunk.id, )
pgcrypto_decrypt( if limit is not None:
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text query = query.limit(limit)
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
).where(DocumentChunk.collection_name == collection_name)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
ids = [[row.id for row in results]]
documents = [[row.text for row in results]]
metadatas = [[row.vmetadata for row in results]]
else:
query = self.session.query(DocumentChunk).filter( results = query.all()
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
results = query.all() if not results:
return None
if not results: ids = [[result.id for result in results]]
return None documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
self.session.rollback() # read-only transaction
return GetResult(ids=ids, documents=documents, metadatas=metadatas) return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e: except Exception as e:
self.session.rollback()
log.exception(f"Error during get: {e}") log.exception(f"Error during get: {e}")
return None return None
@ -645,33 +358,17 @@ class PgvectorClient(VectorDBBase):
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
try: try:
if PGVECTOR_PGCRYPTO: query = self.session.query(DocumentChunk).filter(
wheres = [DocumentChunk.collection_name == collection_name] DocumentChunk.collection_name == collection_name
if ids: )
wheres.append(DocumentChunk.id.in_(ids)) if ids:
if filter: query = query.filter(DocumentChunk.id.in_(ids))
for key, value in filter.items(): if filter:
wheres.append( for key, value in filter.items():
pgcrypto_decrypt( query = query.filter(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB DocumentChunk.vmetadata[key].astext == str(value)
)[key].astext )
== str(value) deleted = query.delete(synchronize_session=False)
)
stmt = DocumentChunk.__table__.delete().where(*wheres)
result = self.session.execute(stmt)
deleted = result.rowcount
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit() self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.") log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e: except Exception as e:
@ -702,10 +399,8 @@ class PgvectorClient(VectorDBBase):
.first() .first()
is not None is not None
) )
self.session.rollback() # read-only transaction
return exists return exists
except Exception as e: except Exception as e:
self.session.rollback()
log.exception(f"Error checking collection existence: {e}") log.exception(f"Error checking collection existence: {e}")
return False return False
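The PGVECTOR_PGCRYPTO path shown above stores text and vmetadata as BYTEA produced by pgp_sym_encrypt and decrypts them with pgp_sym_decrypt on read. A standalone sketch of that round trip, assuming a reachable Postgres with the pgcrypto extension; the demo_secrets table and connection URL are illustrative and not part of Open WebUI:

# Hedged sketch: demonstrates only the pgp_sym_encrypt/pgp_sym_decrypt round trip.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")  # placeholder URL
key = "change-me"  # stands in for PGVECTOR_PGCRYPTO_KEY

with engine.begin() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto"))
    conn.execute(
        text("CREATE TABLE IF NOT EXISTS demo_secrets (id text PRIMARY KEY, body bytea)")
    )
    conn.execute(
        text(
            "INSERT INTO demo_secrets (id, body) "
            "VALUES (:id, pgp_sym_encrypt(:body, :key)) ON CONFLICT (id) DO NOTHING"
        ),
        {"id": "1", "body": "sensitive chunk text", "key": key},
    )
    plain = conn.execute(
        text("SELECT pgp_sym_decrypt(body, :key) FROM demo_secrets WHERE id = :id"),
        {"id": "1", "key": key},
    ).scalar_one()
    print(plain)  # -> sensitive chunk text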

View file

@ -3,19 +3,10 @@ import logging
import time # for measuring elapsed time import time # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec from pinecone import Pinecone, ServerlessSpec
# Add gRPC support for better performance (Pinecone best practice)
try:
from pinecone.grpc import PineconeGRPC
GRPC_AVAILABLE = True
except ImportError:
GRPC_AVAILABLE = False
import asyncio # for async upserts import asyncio # for async upserts
import functools # for partial binding in async tasks import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts import concurrent.futures # for parallel batch upserts
import random # for jitter in retry backoff
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
@ -31,13 +22,13 @@ from open_webui.config import (
PINECONE_METRIC, PINECONE_METRIC,
PINECONE_CLOUD, PINECONE_CLOUD,
) )
from open_webui.retrieval.vector.utils import process_metadata from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient(VectorDBBase): class PineconeClient(VectorDBBase):
@ -56,24 +47,7 @@ class PineconeClient(VectorDBBase):
self.cloud = PINECONE_CLOUD self.cloud = PINECONE_CLOUD
# Initialize Pinecone client for improved performance # Initialize Pinecone client for improved performance
if GRPC_AVAILABLE: self.client = Pinecone(api_key=self.api_key)
# Use gRPC client for better performance (Pinecone recommendation)
self.client = PineconeGRPC(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = True
log.info("Using Pinecone gRPC client for optimal performance")
else:
# Fallback to HTTP client with enhanced connection pooling
self.client = Pinecone(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = False
log.info("Using Pinecone HTTP client (gRPC not available)")
# Persistent executor for batch operations # Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@ -117,53 +91,12 @@ class PineconeClient(VectorDBBase):
log.info(f"Using existing Pinecone index '{self.index_name}'") log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index # Connect to the index
self.index = self.client.Index( self.index = self.client.Index(self.index_name)
self.index_name,
pool_threads=20, # Enhanced connection pool for index operations
)
except Exception as e: except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}") log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}") raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _retry_pinecone_operation(self, operation_func, max_retries=3):
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
for attempt in range(max_retries):
try:
return operation_func()
except Exception as e:
error_str = str(e).lower()
# Check if it's a retryable error (rate limits, network issues, timeouts)
is_retryable = any(
keyword in error_str
for keyword in [
"rate limit",
"quota",
"timeout",
"network",
"connection",
"unavailable",
"internal error",
"429",
"500",
"502",
"503",
"504",
]
)
if not is_retryable or attempt == max_retries - 1:
# Don't retry for non-retryable errors or on final attempt
raise
# Exponential backoff with jitter
delay = (2**attempt) + random.uniform(0, 1)
log.warning(
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
f"retrying in {delay:.2f}s: {e}"
)
time.sleep(delay)
def _create_points( def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@ -183,7 +116,7 @@ class PineconeClient(VectorDBBase):
point = { point = {
"id": item["id"], "id": item["id"],
"values": item["vector"], "values": item["vector"],
"metadata": process_metadata(metadata), "metadata": metadata,
} }
points.append(point) points.append(point)
return points return points
@ -290,8 +223,7 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds") log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info( log.info(
f"Successfully inserted {len(points)} vectors in parallel batches " f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
f"into '{collection_name_with_prefix}'"
) )
def upsert(self, collection_name: str, items: List[VectorItem]) -> None: def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@ -322,8 +254,7 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds") log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info( log.info(
f"Successfully upserted {len(points)} vectors in parallel batches " f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
f"into '{collection_name_with_prefix}'"
) )
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None: async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@ -354,8 +285,7 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async insert batch: {result}") log.error(f"Error in async insert batch: {result}")
raise result raise result
log.info( log.info(
f"Successfully async inserted {len(points)} vectors in batches " f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
f"into '{collection_name_with_prefix}'"
) )
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None: async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@ -386,8 +316,7 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async upsert batch: {result}") log.error(f"Error in async upsert batch: {result}")
raise result raise result
log.info( log.info(
f"Successfully async upserted {len(points)} vectors in batches " f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
f"into '{collection_name_with_prefix}'"
) )
def search( def search(
@ -528,12 +457,10 @@ class PineconeClient(VectorDBBase):
# This is a limitation of Pinecone - be careful with ID uniqueness # This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids) self.index.delete(ids=batch_ids)
log.debug( log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID " f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
f"from '{collection_name_with_prefix}'"
) )
log.info( log.info(
f"Successfully deleted {len(ids)} vectors by ID " f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
f"from '{collection_name_with_prefix}'"
) )
elif filter: elif filter:
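The _retry_pinecone_operation helper above wraps calls in exponential backoff with jitter for rate limits and transient network errors. A generic, self-contained sketch of the same pattern (not tied to the Pinecone SDK; the keyword list is illustrative):

import random
import time


def retry_with_backoff(operation, max_retries=3, retryable=("rate limit", "timeout", "503")):
    """Retry operation() on transient errors using exponential backoff with jitter."""
    for attempt in range(max_retries):
        try:
            return operation()
        except Exception as e:
            transient = any(keyword in str(e).lower() for keyword in retryable)
            if not transient or attempt == max_retries - 1:
                raise
            delay = (2**attempt) + random.uniform(0, 1)  # ~1-2s, 2-3s, 4-5s, ...
            time.sleep(delay)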

View file

@ -18,26 +18,23 @@ from open_webui.config import (
QDRANT_ON_DISK, QDRANT_ON_DISK,
QDRANT_GRPC_PORT, QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC, QDRANT_PREFER_GRPC,
QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
) )
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999 NO_LIMIT = 999999999
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase): class QdrantClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = QDRANT_COLLECTION_PREFIX self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT self.GRPC_PORT = QDRANT_GRPC_PORT
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI: if not self.QDRANT_URI:
self.client = None self.client = None
@ -55,14 +52,9 @@ class QdrantClient(VectorDBBase):
grpc_port=self.GRPC_PORT, grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC, prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY, api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
) )
else: else:
self.client = Qclient( self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY,
timeout=QDRANT_TIMEOUT,
)
def _result_to_get_result(self, points) -> GetResult: def _result_to_get_result(self, points) -> GetResult:
ids = [] ids = []
@ -92,30 +84,8 @@ class QdrantClient(VectorDBBase):
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK, on_disk=self.QDRANT_ON_DISK,
), ),
hnsw_config=models.HnswConfigDiff(
m=self.QDRANT_HNSW_M,
),
) )
# Create payload indexes for efficient filtering
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.hash",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.file_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
log.info(f"collection {collection_name_with_prefix} successfully created!") log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension): def _create_collection_if_not_exists(self, collection_name, dimension):
@ -181,23 +151,23 @@ class QdrantClient(VectorDBBase):
) )
) )
points = self.client.scroll( points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
scroll_filter=models.Filter(should=field_conditions), query_filter=models.Filter(should=field_conditions),
limit=limit, limit=limit,
) )
return self._result_to_get_result(points[0]) return self._result_to_get_result(points.points)
except Exception as e: except Exception as e:
log.exception(f"Error querying a collection '{collection_name}': {e}") log.exception(f"Error querying a collection '{collection_name}': {e}")
return None return None
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.
points = self.client.scroll( points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT, # otherwise qdrant would set limit to 10! limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
) )
return self._result_to_get_result(points[0]) return self._result_to_get_result(points.points)
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
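The hunks above pin an explicit HNSW m value and add keyword payload indexes on metadata.hash and metadata.file_id when a collection is created. A minimal standalone sketch of that setup with qdrant-client (the collection name, URL, and vector size are illustrative):

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # placeholder URL

client.create_collection(
    collection_name="open-webui_demo",
    vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE),
    hnsw_config=models.HnswConfigDiff(m=16),
)
for field in ("metadata.hash", "metadata.file_id"):
    client.create_payload_index(
        collection_name="open-webui_demo",
        field_name=field,
        field_schema=models.KeywordIndexParams(type=models.KeywordIndexType.KEYWORD),
    )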

View file

@ -1,5 +1,5 @@
import logging import logging
from typing import Optional, Tuple, List, Dict, Any from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import grpc import grpc
@ -9,10 +9,8 @@ from open_webui.config import (
QDRANT_ON_DISK, QDRANT_ON_DISK,
QDRANT_PREFER_GRPC, QDRANT_PREFER_GRPC,
QDRANT_URI, QDRANT_URI,
QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
) )
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
GetResult, GetResult,
SearchResult, SearchResult,
@ -25,61 +23,39 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models from qdrant_client.models import models
NO_LIMIT = 999999999 NO_LIMIT = 999999999
TENANT_ID_FIELD = "tenant_id"
DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
return models.FieldCondition(
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
return models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
class QdrantClient(VectorDBBase): class QdrantClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = QDRANT_COLLECTION_PREFIX self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT self.GRPC_PORT = QDRANT_GRPC_PORT
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI: if not self.QDRANT_URI:
raise ValueError( self.client = None
"QDRANT_URI is not set. Please configure it in the environment variables." return
)
# Unified handling for either scheme # Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI) parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port http_port = parsed.port or 6333 # default REST port
self.client = ( if self.PREFER_GRPC:
Qclient( self.client = Qclient(
host=host, host=host,
port=http_port, port=http_port,
grpc_port=self.GRPC_PORT, grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC, prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY, api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
) )
if self.PREFER_GRPC else:
else Qclient( self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
)
)
# Main collection types for multi-tenancy # Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
@ -89,13 +65,23 @@ class QdrantClient(VectorDBBase):
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult: def _result_to_get_result(self, points) -> GetResult:
ids, documents, metadatas = [], [], [] ids = []
documents = []
metadatas = []
for point in points: for point in points:
payload = point.payload payload = point.payload
ids.append(point.id) ids.append(point.id)
documents.append(payload["text"]) documents.append(payload["text"])
metadatas.append(payload["metadata"]) metadatas.append(payload["metadata"])
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
""" """
@ -103,13 +89,6 @@ class QdrantClient(VectorDBBase):
Returns: Returns:
tuple: (collection_name, tenant_id) tuple: (collection_name, tenant_id)
WARNING: This mapping relies on current Open WebUI naming conventions for
collection names. If Open WebUI changes how it generates collection names
(e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
formats), this mapping will break and route data to incorrect collections.
POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
DATA MAPPING INSIDE THE DATABASE.
""" """
# Check for user memory collections # Check for user memory collections
tenant_id = collection_name tenant_id = collection_name
@ -134,53 +113,143 @@ class QdrantClient(VectorDBBase):
else: else:
return self.KNOWLEDGE_COLLECTION, tenant_id return self.KNOWLEDGE_COLLECTION, tenant_id
def _create_multi_tenant_collection( def _extract_error_message(self, exception):
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION """
Extract error message from either HTTP or gRPC exceptions
Returns:
tuple: (status_code, error_message)
"""
# Check if it's an HTTP exception
if isinstance(exception, UnexpectedResponse):
try:
error_data = exception.structured()
error_msg = error_data.get("status", {}).get("error", "")
return exception.status_code, error_msg
except Exception as inner_e:
log.error(f"Failed to parse HTTP error: {inner_e}")
return exception.status_code, str(exception)
# Check if it's a gRPC exception
elif isinstance(exception, grpc.RpcError):
# Extract status code from gRPC error
status_code = None
if hasattr(exception, "code") and callable(exception.code):
status_code = exception.code().value[0]
# Extract error message
error_msg = str(exception)
if "details =" in error_msg:
# Parse the details line which contains the actual error message
try:
details_line = [
line.strip()
for line in error_msg.split("\n")
if "details =" in line
][0]
error_msg = details_line.split("details =")[1].strip(' "')
except (IndexError, AttributeError):
# Fall back to full message if parsing fails
pass
return status_code, error_msg
# For any other type of exception
return None, str(exception)
def _is_collection_not_found_error(self, exception):
"""
Check if the exception is due to collection not found, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# HTTP error (404)
if (
status_code == 404
and "Collection" in error_msg
and "doesn't exist" in error_msg
):
return True
# gRPC error (NOT_FOUND status)
if (
isinstance(exception, grpc.RpcError)
and exception.code() == grpc.StatusCode.NOT_FOUND
):
return True
return False
def _is_dimension_mismatch_error(self, exception):
"""
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# Common patterns in both HTTP and gRPC
return (
"Vector dimension error" in error_msg
or "dimensions mismatch" in error_msg
or "invalid vector size" in error_msg
)
main (continuation of _create_multi_tenant_collection, whose def line falls above this hunk):

):
    """
    Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
    """
    self.client.create_collection(
        collection_name=mt_collection_name,
        vectors_config=models.VectorParams(
            size=dimension,
            distance=models.Distance.COSINE,
            on_disk=self.QDRANT_ON_DISK,
        ),
        # Disable global index building due to multitenancy
        # For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
        hnsw_config=models.HnswConfigDiff(
            payload_m=self.QDRANT_HNSW_M,
            m=0,
        ),
    )
    log.info(
        f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
    )
    self.client.create_payload_index(
        collection_name=mt_collection_name,
        field_name=TENANT_ID_FIELD,
        field_schema=models.KeywordIndexParams(
            type=models.KeywordIndexType.KEYWORD,
            is_tenant=True,
            on_disk=self.QDRANT_ON_DISK,
        ),
    )
    for field in ("metadata.hash", "metadata.file_id"):
        self.client.create_payload_index(
            collection_name=mt_collection_name,
            field_name=field,
            field_schema=models.KeywordIndexParams(
                type=models.KeywordIndexType.KEYWORD,
                on_disk=self.QDRANT_ON_DISK,
            ),
        )

def _create_points(
    self, items: List[VectorItem], tenant_id: str
) -> List[PointStruct]:
    """
    Create point structs from vector items with tenant ID.
    """

v0.6.12:

def _create_multi_tenant_collection_if_not_exists(
    self, mt_collection_name: str, dimension: int = 384
):
    """
    Creates a collection with multi-tenancy configuration if it doesn't exist.
    Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
    When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
    """
    try:
        # Try to create the collection directly - will fail if it already exists
        self.client.create_collection(
            collection_name=mt_collection_name,
            vectors_config=models.VectorParams(
                size=dimension,
                distance=models.Distance.COSINE,
                on_disk=self.QDRANT_ON_DISK,
            ),
            hnsw_config=models.HnswConfigDiff(
                payload_m=16,  # Enable per-tenant indexing
                m=0,
                on_disk=self.QDRANT_ON_DISK,
            ),
        )
        # Create tenant ID payload index
        self.client.create_payload_index(
            collection_name=mt_collection_name,
            field_name="tenant_id",
            field_schema=models.KeywordIndexParams(
                type=models.KeywordIndexType.KEYWORD,
                is_tenant=True,
                on_disk=self.QDRANT_ON_DISK,
            ),
            wait=True,
        )
        log.info(
            f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
        )
    except (UnexpectedResponse, grpc.RpcError) as e:
        # Check for the specific error indicating collection already exists
        status_code, error_msg = self._extract_error_message(e)
        # HTTP status code 409 or gRPC ALREADY_EXISTS
        if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
            isinstance(e, grpc.RpcError)
            and e.code() == grpc.StatusCode.ALREADY_EXISTS
        ):
            if "already exists" in error_msg:
                log.debug(f"Collection {mt_collection_name} already exists")
                return
        # If it's not an already exists error, re-raise
        raise e
    except Exception as e:
        raise e

def _create_points(self, items: list[VectorItem], tenant_id: str):
    """
    Create point structs from vector items with tenant ID.
    """
@ -191,42 +260,56 @@ class QdrantClient(VectorDBBase):
            payload={
                "text": item["text"],
                "metadata": item["metadata"],
                TENANT_ID_FIELD: tenant_id,    (main)
                "tenant_id": tenant_id,        (v0.6.12)
            },
        )
        for item in items
    ]
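Both versions enforce tenant isolation purely through Qdrant payload filters: every point carries the tenant ID in its payload, and every read or write is scoped by a filter on that field. A minimal sketch of such filter builders, assuming the helper names used by the main branch (_tenant_filter, _metadata_filter) and the public qdrant-client models API; this is illustrative, not part of the diff:

from qdrant_client import models

TENANT_ID_FIELD = "tenant_id"  # payload key used above


def _tenant_filter(tenant_id: str) -> models.FieldCondition:
    # Match only points that belong to this tenant.
    return models.FieldCondition(
        key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
    )


def _metadata_filter(key: str, value) -> models.FieldCondition:
    # Match on a nested metadata field, e.g. "metadata.file_id".
    return models.FieldCondition(
        key=f"metadata.{key}", match=models.MatchValue(value=value)
    )


# All points of tenant "user-123" whose metadata.file_id equals "abc":
combined = models.Filter(
    must=[_tenant_filter("user-123"), _metadata_filter("file_id", "abc")]
)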
def _ensure_collection(
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
"""
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
"""
if not self.client.collection_exists(collection_name=mt_collection_name):
self._create_multi_tenant_collection(mt_collection_name, dimension)
main:

def has_collection(self, collection_name: str) -> bool:
    """
    Check if a logical collection exists by checking for any points with the tenant ID.
    """
    if not self.client:
        return False
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    if not self.client.collection_exists(collection_name=mt_collection):
        return False
    tenant_filter = _tenant_filter(tenant_id)
    count_result = self.client.count(
        collection_name=mt_collection,
        count_filter=models.Filter(must=[tenant_filter]),
    )
    return count_result.count > 0

v0.6.12:

def has_collection(self, collection_name: str) -> bool:
    """
    Check if a logical collection exists by checking for any points with the tenant ID.
    """
    if not self.client:
        return False

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    # Create tenant filter
    tenant_filter = models.FieldCondition(
        key="tenant_id", match=models.MatchValue(value=tenant_id)
    )

    try:
        # Try directly querying - most of the time collection should exist
        response = self.client.query_points(
            collection_name=mt_collection,
            query_filter=models.Filter(must=[tenant_filter]),
            limit=1,
        )
        # Collection exists with this tenant ID if there are points
        return len(response.points) > 0
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(f"Collection {mt_collection} doesn't exist")
            return False
        else:
            # For other API errors, log and return False
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error: {error_msg}")
            return False
    except Exception as e:
        # For any other errors, log and return False
        log.debug(f"Error checking collection {mt_collection}: {e}")
        return False
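The two existence checks differ in strategy: main asks Qdrant whether the shared collection exists and then counts points for the tenant, while v0.6.12 probes with a one-point query and interprets "collection not found" errors. A standalone sketch of the main-style check, assuming a local Qdrant instance; the URL and collection name are placeholders:

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # assumed local instance


def logical_collection_has_points(mt_collection: str, tenant_id: str) -> bool:
    # A logical collection "exists" when the shared physical collection exists
    # and holds at least one point tagged with this tenant ID.
    if not client.collection_exists(collection_name=mt_collection):
        return False
    result = client.count(
        collection_name=mt_collection,
        count_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="tenant_id", match=models.MatchValue(value=tenant_id)
                )
            ]
        ),
        exact=True,
    )
    return result.count > 0


print(logical_collection_has_points("open_webui", "user-123"))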
def delete(
    self,
    collection_name: str,
    ids: Optional[List[str]] = None,          (main)
    ids: Optional[list[str]] = None,          (v0.6.12)
    filter: Optional[Dict[str, Any]] = None,  (main)
    filter: Optional[dict] = None,            (v0.6.12)
):
    """
    Delete vectors by ID or filter from a collection with tenant isolation.
@ -234,76 +317,189 @@ class QdrantClient(VectorDBBase):
main:

    if not self.client:
        return None
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    if not self.client.collection_exists(collection_name=mt_collection):
        log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
        return None
    must_conditions = [_tenant_filter(tenant_id)]
    should_conditions = []
    if ids:
        should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
    elif filter:
        must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
    return self.client.delete(
        collection_name=mt_collection,
        points_selector=models.FilterSelector(
            filter=models.Filter(must=must_conditions, should=should_conditions)
        ),
    )

v0.6.12:

    if not self.client:
        return None

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    # Create tenant filter
    tenant_filter = models.FieldCondition(
        key="tenant_id", match=models.MatchValue(value=tenant_id)
    )

    must_conditions = [tenant_filter]
    should_conditions = []
    if ids:
        for id_value in ids:
            should_conditions.append(
                models.FieldCondition(
                    key="metadata.id",
                    match=models.MatchValue(value=id_value),
                ),
            )
    elif filter:
        for key, value in filter.items():
            must_conditions.append(
                models.FieldCondition(
                    key=f"metadata.{key}",
                    match=models.MatchValue(value=value),
                ),
            )

    try:
        # Try to delete directly - most of the time collection should exist
        update_result = self.client.delete(
            collection_name=mt_collection,
            points_selector=models.FilterSelector(
                filter=models.Filter(must=must_conditions, should=should_conditions)
            ),
        )
        return update_result
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(
                f"Collection {mt_collection} doesn't exist, nothing to delete"
            )
            return None
        else:
            # For other API errors, log and re-raise
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error: {error_msg}")
            raise
    except Exception as e:
        # For non-Qdrant exceptions, re-raise
        raise
main:

def search(
    self, collection_name: str, vectors: List[List[float | int]], limit: int
) -> Optional[SearchResult]:
    """
    Search for the nearest neighbor items based on the vectors with tenant isolation.
    """
    if not self.client or not vectors:
        return None
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    if not self.client.collection_exists(collection_name=mt_collection):
        log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
        return None
    tenant_filter = _tenant_filter(tenant_id)
    query_response = self.client.query_points(
        collection_name=mt_collection,
        query=vectors[0],
        limit=limit,
        query_filter=models.Filter(must=[tenant_filter]),
    )
    get_result = self._result_to_get_result(query_response.points)
    return SearchResult(
        ids=get_result.ids,
        documents=get_result.documents,
        metadatas=get_result.metadatas,
        distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
    )

def query(
    self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
):
    """
    Query points with filters and tenant isolation.
    """
    if not self.client:
        return None
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    if not self.client.collection_exists(collection_name=mt_collection):
        log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
        return None
    if limit is None:
        limit = NO_LIMIT
    tenant_filter = _tenant_filter(tenant_id)
    field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
    combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
    points = self.client.scroll(
        collection_name=mt_collection,
        scroll_filter=combined_filter,
        limit=limit,
    )
    return self._result_to_get_result(points[0])

v0.6.12:

def search(
    self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
    """
    Search for the nearest neighbor items based on the vectors with tenant isolation.
    """
    if not self.client:
        return None

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    # Get the vector dimension from the query vector
    dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None

    try:
        # Try the search operation directly - most of the time collection should exist
        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
        # Ensure vector dimensions match the collection
        collection_dim = self.client.get_collection(
            mt_collection
        ).config.params.vectors.size
        if collection_dim != dimension:
            if collection_dim < dimension:
                vectors = [vector[:collection_dim] for vector in vectors]
            else:
                vectors = [
                    vector + [0] * (collection_dim - dimension)
                    for vector in vectors
                ]
        # Search with tenant filter
        prefetch_query = models.Prefetch(
            filter=models.Filter(must=[tenant_filter]),
            limit=NO_LIMIT,
        )
        query_response = self.client.query_points(
            collection_name=mt_collection,
            query=vectors[0],
            prefetch=prefetch_query,
            limit=limit,
        )
        get_result = self._result_to_get_result(query_response.points)
        return SearchResult(
            ids=get_result.ids,
            documents=get_result.documents,
            metadatas=get_result.metadatas,
            # qdrant distance is [-1, 1], normalize to [0, 1]
            distances=[
                [(point.score + 1.0) / 2.0 for point in query_response.points]
            ],
        )
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(
                f"Collection {mt_collection} doesn't exist, search returns None"
            )
            return None
        else:
            # For other API errors, log and re-raise
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during search: {error_msg}")
            raise
    except Exception as e:
        # For non-Qdrant exceptions, log and return None
        log.exception(f"Error searching collection '{collection_name}': {e}")
        return None

def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
    """
    Query points with filters and tenant isolation.
    """
    if not self.client:
        return None

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    # Set default limit if not provided
    if limit is None:
        limit = NO_LIMIT

    # Create tenant filter
    tenant_filter = models.FieldCondition(
        key="tenant_id", match=models.MatchValue(value=tenant_id)
    )

    # Create metadata filters
    field_conditions = []
    for key, value in filter.items():
        field_conditions.append(
            models.FieldCondition(
                key=f"metadata.{key}", match=models.MatchValue(value=value)
            )
        )

    # Combine tenant filter with metadata filters
    combined_filter = models.Filter(must=[tenant_filter, *field_conditions])

    try:
        # Try the query directly - most of the time collection should exist
        points = self.client.query_points(
            collection_name=mt_collection,
            query_filter=combined_filter,
            limit=limit,
        )
        return self._result_to_get_result(points.points)
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(
                f"Collection {mt_collection} doesn't exist, query returns None"
            )
            return None
        else:
            # For other API errors, log and re-raise
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during query: {error_msg}")
            raise
    except Exception as e:
        # For non-Qdrant exceptions, log and re-raise
        log.exception(f"Error querying collection '{collection_name}': {e}")
        return None
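Both implementations rescale Qdrant's cosine score, which lies in [-1, 1], onto [0, 1] so that relevance thresholds behave like the other vector backends. The mapping in isolation:

def normalize_cosine_score(score: float) -> float:
    # Map a cosine similarity in [-1, 1] onto [0, 1], where 1.0 is a perfect match.
    return (score + 1.0) / 2.0


assert normalize_cosine_score(1.0) == 1.0   # identical direction
assert normalize_cosine_score(0.0) == 0.5   # orthogonal
assert normalize_cosine_score(-1.0) == 0.0  # opposite direction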
def get(self, collection_name: str) -> Optional[GetResult]:
    """
@ -311,36 +507,169 @@ class QdrantClient(VectorDBBase):
""" """
main:

    if not self.client:
        return None
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    if not self.client.collection_exists(collection_name=mt_collection):
        log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
        return None
    tenant_filter = _tenant_filter(tenant_id)
    points = self.client.scroll(
        collection_name=mt_collection,
        scroll_filter=models.Filter(must=[tenant_filter]),
        limit=NO_LIMIT,
    )
    return self._result_to_get_result(points[0])

v0.6.12:

    if not self.client:
        return None

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    # Create tenant filter
    tenant_filter = models.FieldCondition(
        key="tenant_id", match=models.MatchValue(value=tenant_id)
    )

    try:
        # Try to get points directly - most of the time collection should exist
        points = self.client.query_points(
            collection_name=mt_collection,
            query_filter=models.Filter(must=[tenant_filter]),
            limit=NO_LIMIT,
        )
        return self._result_to_get_result(points.points)
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
            return None
        else:
            # For other API errors, log and re-raise
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during get: {error_msg}")
            raise
    except Exception as e:
        # For non-Qdrant exceptions, log and return None
        log.exception(f"Error getting collection '{collection_name}': {e}")
        return None
def _handle_operation_with_error_retry(
self, operation_name, mt_collection, points, dimension
):
"""
Private helper to handle common error cases for insert and upsert operations.
Args:
operation_name: 'insert' or 'upsert'
mt_collection: The multi-tenant collection name
points: The vector points to insert/upsert
dimension: The dimension of the vectors
Returns:
The operation result (for upsert) or None (for insert)
"""
try:
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
except (UnexpectedResponse, grpc.RpcError) as e:
# Handle collection not found
if self._is_collection_not_found_error(e):
log.info(
f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
)
# Create collection with correct dimensions from our vectors
self._create_multi_tenant_collection_if_not_exists(
mt_collection_name=mt_collection, dimension=dimension
)
# Try operation again - no need for dimension adjustment since we just created with correct dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
# Handle dimension mismatch
elif self._is_dimension_mismatch_error(e):
# For dimension errors, the collection must exist, so get its configuration
mt_collection_info = self.client.get_collection(mt_collection)
existing_size = mt_collection_info.config.params.vectors.size
log.info(
f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
)
if existing_size < dimension:
# Truncate vectors to fit
log.info(
f"Truncating vectors from {dimension} to {existing_size} dimensions"
)
points = [
PointStruct(
id=point.id,
vector=point.vector[:existing_size],
payload=point.payload,
)
for point in points
]
elif existing_size > dimension:
# Pad vectors with zeros
log.info(
f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
)
points = [
PointStruct(
id=point.id,
vector=point.vector
+ [0] * (existing_size - len(point.vector)),
payload=point.payload,
)
for point in points
]
# Try operation again with adjusted dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
else:
# Not a known error we can handle, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unhandled Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
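The dimension-mismatch branch above silently truncates or zero-pads vectors so the retry can succeed against the existing collection. The adjustment step in isolation, as a standalone sketch (a lossy compatibility fallback, not an exact copy of the method):

from typing import List


def fit_vector(vector: List[float], expected_size: int) -> List[float]:
    # Truncate extra dimensions or pad with zeros so the vector matches the
    # dimension the existing collection was created with.
    if len(vector) > expected_size:
        return vector[:expected_size]
    if len(vector) < expected_size:
        return vector + [0.0] * (expected_size - len(vector))
    return vector


print(fit_vector([0.1, 0.2, 0.3], 2))  # [0.1, 0.2]
print(fit_vector([0.1, 0.2], 4))       # [0.1, 0.2, 0.0, 0.0]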
def insert(self, collection_name: str, items: list[VectorItem]):
"""
Insert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"insert", mt_collection, points, dimension
)
main:

def upsert(self, collection_name: str, items: List[VectorItem]):
    """
    Upsert items with tenant ID.
    """
    if not self.client or not items:
        return None
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    dimension = len(items[0]["vector"])
    self._ensure_collection(mt_collection, dimension)
    points = self._create_points(items, tenant_id)
    self.client.upload_points(mt_collection, points)
    return None

def insert(self, collection_name: str, items: List[VectorItem]):
    """
    Insert items with tenant ID.
    """
    return self.upsert(collection_name, items)

v0.6.12:

def upsert(self, collection_name: str, items: list[VectorItem]):
    """
    Upsert items with tenant ID.
    """
    if not self.client or not items:
        return None

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    # Get dimensions from the actual vectors
    dimension = len(items[0]["vector"]) if items else None

    # Create points with tenant ID
    points = self._create_points(items, tenant_id)

    # Handle the operation with error retry
    return self._handle_operation_with_error_retry(
        "upsert", mt_collection, points, dimension
    )
def reset(self):
    """
@ -348,9 +677,11 @@ class QdrantClient(VectorDBBase):
""" """
main:

    if not self.client:
        return None
    for collection in self.client.get_collections().collections:
        if collection.name.startswith(self.collection_prefix):
            self.client.delete_collection(collection_name=collection.name)

v0.6.12:

    if not self.client:
        return None

    collection_names = self.client.get_collections().collections
    for collection_name in collection_names:
        if collection_name.name.startswith(self.collection_prefix):
            self.client.delete_collection(collection_name=collection_name.name)
def delete_collection(self, collection_name: str):
    """
@ -358,13 +689,24 @@ class QdrantClient(VectorDBBase):
""" """
main:

    if not self.client:
        return None
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    if not self.client.collection_exists(collection_name=mt_collection):
        log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
        return None
    self.client.delete(
        collection_name=mt_collection,
        points_selector=models.FilterSelector(
            filter=models.Filter(must=[_tenant_filter(tenant_id)])
        ),
    )
    if self.client.get_collection(mt_collection).points_count == 0:
        self.client.delete_collection(mt_collection)

v0.6.12:

    if not self.client:
        return None

    # Map to multi-tenant collection and tenant ID
    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    tenant_filter = models.FieldCondition(
        key="tenant_id", match=models.MatchValue(value=tenant_id)
    )
    field_conditions = [tenant_filter]

    update_result = self.client.delete(
        collection_name=mt_collection,
        points_selector=models.FilterSelector(
            filter=models.Filter(must=field_conditions)
        ),
    )
    return update_result

View file

@ -1,774 +0,0 @@
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
GetResult,
SearchResult,
)
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
from typing import List, Optional, Dict, Any, Union
import logging
import boto3
log = logging.getLogger(__name__)
class S3VectorClient(VectorDBBase):
"""
AWS S3 Vector integration for Open WebUI Knowledge.
"""
def __init__(self):
self.bucket_name = S3_VECTOR_BUCKET_NAME
self.region = S3_VECTOR_REGION
# Simple validation - log warnings instead of raising exceptions
if not self.bucket_name:
log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
if not self.region:
log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
if self.bucket_name and self.region:
try:
self.client = boto3.client("s3vectors", region_name=self.region)
log.info(
f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
)
except Exception as e:
log.error(f"Failed to initialize S3Vector client: {e}")
self.client = None
else:
self.client = None
def _create_index(
self,
index_name: str,
dimension: int,
data_type: str = "float32",
distance_metric: str = "cosine",
) -> None:
"""
Create a new index in the S3 vector bucket for the given collection if it does not exist.
"""
if self.has_collection(index_name):
log.debug(f"Index '{index_name}' already exists, skipping creation")
return
try:
self.client.create_index(
vectorBucketName=self.bucket_name,
indexName=index_name,
dataType=data_type,
dimension=dimension,
distanceMetric=distance_metric,
)
log.info(
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
)
except Exception as e:
log.error(f"Error creating S3 index '{index_name}': {e}")
raise
def _filter_metadata(
self, metadata: Dict[str, Any], item_id: str
) -> Dict[str, Any]:
"""
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
"""
if not isinstance(metadata, dict) or len(metadata) <= 10:
return metadata
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
important_keys = [
"text", # The actual document content
"file_id", # File ID
"source", # Document source file
"title", # Document title
"page", # Page number
"total_pages", # Total pages in document
"embedding_config", # Embedding configuration
"created_by", # User who created it
"name", # Document name
"hash", # Content hash
]
filtered_metadata = {}
# First, add important keys if they exist
for key in important_keys:
if key in metadata:
filtered_metadata[key] = metadata[key]
if len(filtered_metadata) >= 10:
break
# If we still have room, add other keys
if len(filtered_metadata) < 10:
for key, value in metadata.items():
if key not in filtered_metadata:
filtered_metadata[key] = value
if len(filtered_metadata) >= 10:
break
log.warning(
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
)
return filtered_metadata
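An illustrative use of the 10-key cap above. The constructor only logs warnings when the bucket/region settings are absent, and _filter_metadata never touches AWS, so this runs locally:

client = S3VectorClient()

oversized = {f"extra_{i}": i for i in range(12)}
oversized.update({"text": "chunk body", "file_id": "f1", "source": "doc.pdf", "hash": "abc"})

trimmed = client._filter_metadata(oversized, item_id="chunk-1")
assert len(trimmed) <= 10
assert "text" in trimmed and "file_id" in trimmed  # prioritized keys are kept first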
def has_collection(self, collection_name: str) -> bool:
"""
Check if a vector index exists using direct lookup.
This avoids pagination issues with list_indexes() and is significantly faster.
"""
try:
self.client.get_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
return True
except Exception as e:
log.error(f"Error checking if index '{collection_name}' exists: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
"""
Delete an entire S3 Vector index/collection.
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
return
try:
log.info(f"Deleting collection '{collection_name}'")
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
log.info(f"Successfully deleted collection '{collection_name}'")
except Exception as e:
log.error(f"Error deleting collection '{collection_name}': {e}")
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert vector items into the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to insert")
return
dimension = len(items[0]["vector"])
try:
if not self.has_collection(collection_name):
log.info(f"Index '{collection_name}' does not exist. Creating index.")
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
)
# Prepare vectors for insertion
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
# Convert metadata to string format for consistency
metadata = process_metadata(metadata)
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
}
)
# Insert vectors in batches of 500 (S3 Vector API limit)
batch_size = 500
for i in range(0, len(vectors), batch_size):
batch = vectors[i : i + batch_size]
self.client.put_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
vectors=batch,
)
log.info(
f"Inserted batch {i//batch_size + 1}: {len(batch)} vectors into index '{collection_name}'."
)
log.info(
f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'."
)
except Exception as e:
log.error(f"Error inserting vectors: {e}")
raise
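The 500-item batching used by insert (and upsert below), shown as a standalone slicing helper; the S3 Vector API rejects larger put_vectors batches:

from typing import Iterable, List, TypeVar

T = TypeVar("T")


def batched(items: List[T], batch_size: int = 500) -> Iterable[List[T]]:
    # Yield consecutive slices of at most batch_size items.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


print([len(batch) for batch in batched(list(range(1205)))])  # [500, 500, 205]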
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to upsert")
return
dimension = len(items[0]["vector"])
log.info(f"Upsert dimension: {dimension}")
try:
if not self.has_collection(collection_name):
log.info(
f"Index '{collection_name}' does not exist. Creating index for upsert."
)
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
)
# Prepare vectors for upsert
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
# Convert metadata to string format for consistency
metadata = process_metadata(metadata)
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
}
)
# Upsert vectors in batches of 500 (S3 Vector API limit)
batch_size = 500
for i in range(0, len(vectors), batch_size):
batch = vectors[i : i + batch_size]
if i == 0: # Log sample info for first batch only
log.info(
f"Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]['key']}, data_type={type(batch[0]['data']['float32'])}, data_len={len(batch[0]['data']['float32'])}"
)
else:
log.info(
f"Upserting batch {i//batch_size + 1}: {len(batch)} vectors."
)
self.client.put_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
vectors=batch,
)
log.info(
f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'."
)
except Exception as e:
log.error(f"Error upserting vectors: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""
Search for similar vectors in a collection using multiple query vectors.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return None
if not vectors:
log.warning("No query vectors provided")
return None
try:
log.info(
f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
)
# Initialize result lists
all_ids = []
all_documents = []
all_metadatas = []
all_distances = []
# Process each query vector
for i, query_vector in enumerate(vectors):
log.debug(f"Processing query vector {i+1}/{len(vectors)}")
# Prepare the query vector in S3 Vector format
query_vector_dict = {"float32": [float(x) for x in query_vector]}
# Call S3 Vector query API
response = self.client.query_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
topK=limit,
queryVector=query_vector_dict,
returnMetadata=True,
returnDistance=True,
)
# Process results for this query
query_ids = []
query_documents = []
query_metadatas = []
query_distances = []
result_vectors = response.get("vectors", [])
for vector in result_vectors:
vector_id = vector.get("key")
vector_metadata = vector.get("metadata", {})
vector_distance = vector.get("distance", 0.0)
# Extract document text from metadata
document_text = ""
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
)
else:
document_text = vector_id
query_ids.append(vector_id)
query_documents.append(document_text)
query_metadatas.append(vector_metadata)
query_distances.append(vector_distance)
# Add this query's results to the overall results
all_ids.append(query_ids)
all_documents.append(query_documents)
all_metadatas.append(query_metadatas)
all_distances.append(query_distances)
log.info(f"Search completed. Found results for {len(all_ids)} queries")
# Return SearchResult format
return SearchResult(
ids=all_ids if all_ids else None,
documents=all_documents if all_documents else None,
metadatas=all_metadatas if all_metadatas else None,
distances=all_distances if all_distances else None,
)
except Exception as e:
log.error(f"Error searching collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return None
elif error_code == "ValidationException":
log.error(f"Invalid query vector dimensions or parameters")
return None
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return None
raise
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""
Query vectors from a collection using metadata filter.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
if not filter:
log.warning("No filter provided, returning all vectors")
return self.get(collection_name)
try:
log.info(f"Querying collection '{collection_name}' with filter: {filter}")
# For S3 Vector, we need to use list_vectors and then filter results
# Since S3 Vector may not support complex server-side filtering,
# we'll retrieve all vectors and filter client-side
# Get all vectors first
all_vectors_result = self.get(collection_name)
if not all_vectors_result or not all_vectors_result.ids:
log.warning("No vectors found in collection")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
# Extract the lists from the result
all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
all_documents = (
all_vectors_result.documents[0] if all_vectors_result.documents else []
)
all_metadatas = (
all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
)
# Apply client-side filtering
filtered_ids = []
filtered_documents = []
filtered_metadatas = []
for i, metadata in enumerate(all_metadatas):
if self._matches_filter(metadata, filter):
if i < len(all_ids):
filtered_ids.append(all_ids[i])
if i < len(all_documents):
filtered_documents.append(all_documents[i])
filtered_metadatas.append(metadata)
# Apply limit if specified
if limit and len(filtered_ids) >= limit:
break
log.info(
f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
)
# Return GetResult format
if filtered_ids:
return GetResult(
ids=[filtered_ids],
documents=[filtered_documents],
metadatas=[filtered_metadatas],
)
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Retrieve all vectors from a collection.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
try:
log.info(f"Retrieving all vectors from collection '{collection_name}'")
# Initialize result lists
all_ids = []
all_documents = []
all_metadatas = []
# Handle pagination
next_token = None
while True:
# Prepare request parameters
request_params = {
"vectorBucketName": self.bucket_name,
"indexName": collection_name,
"returnData": False, # Don't include vector data (not needed for get)
"returnMetadata": True, # Include metadata
"maxResults": 500, # Use reasonable page size
}
if next_token:
request_params["nextToken"] = next_token
# Call S3 Vector API
response = self.client.list_vectors(**request_params)
# Process vectors in this page
vectors = response.get("vectors", [])
for vector in vectors:
vector_id = vector.get("key")
vector_data = vector.get("data", {})
vector_metadata = vector.get("metadata", {})
# Extract the actual vector array
vector_array = vector_data.get("float32", [])
# For documents, we try to extract text from metadata or use the vector ID
document_text = ""
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
)
# Log the actual content for debugging
log.debug(
f"Document text preview (first 200 chars): {str(document_text)[:200]}"
)
else:
document_text = vector_id
all_ids.append(vector_id)
all_documents.append(document_text)
all_metadatas.append(vector_metadata)
# Check if there are more pages
next_token = response.get("nextToken")
if not next_token:
break
log.info(
f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
)
# Return in GetResult format
# The Open WebUI GetResult expects lists of lists, so we wrap each list
if all_ids:
return GetResult(
ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
)
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(
f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
)
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""
Delete vectors by ID or filter from a collection.
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
return
# Check if this is a knowledge collection (not file-specific)
is_knowledge_collection = not collection_name.startswith("file-")
try:
if ids:
# Delete by specific vector IDs/keys
log.info(
f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
)
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
keys=ids,
)
log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
elif filter:
# Handle filter-based deletion
log.info(
f"Deleting vectors by filter from collection '{collection_name}': {filter}"
)
# If this is a knowledge collection and we have a file_id filter,
# also clean up the corresponding file-specific collection
if is_knowledge_collection and "file_id" in filter:
file_id = filter["file_id"]
file_collection_name = f"file-{file_id}"
if self.has_collection(file_collection_name):
log.info(
f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
)
self.delete_collection(file_collection_name)
# For the main collection, implement query-then-delete
# First, query to get IDs matching the filter
query_result = self.query(collection_name, filter)
if query_result and query_result.ids and query_result.ids[0]:
matching_ids = query_result.ids[0]
log.info(
f"Found {len(matching_ids)} vectors matching filter, deleting them"
)
# Delete the matching vectors by ID
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
keys=matching_ids,
)
log.info(
f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
)
else:
log.warning("No vectors found matching the filter criteria")
else:
log.warning("No IDs or filter provided for deletion")
except Exception as e:
log.error(
f"Error deleting vectors from collection '{collection_name}': {e}"
)
raise
def reset(self) -> None:
"""
Reset/clear all vector data. For S3 Vector, this deletes all indexes.
"""
try:
log.warning(
"Reset called - this will delete all vector indexes in the S3 bucket"
)
# List all indexes
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
indexes = response.get("indexes", [])
if not indexes:
log.warning("No indexes found to delete")
return
# Delete all indexes
deleted_count = 0
for index in indexes:
index_name = index.get("indexName")
if index_name:
try:
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=index_name
)
deleted_count += 1
log.info(f"Deleted index: {index_name}")
except Exception as e:
log.error(f"Error deleting index '{index_name}': {e}")
log.info(f"Reset completed: deleted {deleted_count} indexes")
except Exception as e:
log.error(f"Error during reset: {e}")
raise
def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
"""
Check if metadata matches the given filter conditions.
"""
if not isinstance(metadata, dict) or not isinstance(filter, dict):
return False
# Check each filter condition
for key, expected_value in filter.items():
# Handle special operators
if key.startswith("$"):
if key == "$and":
# All conditions must match
if not isinstance(expected_value, list):
continue
for condition in expected_value:
if not self._matches_filter(metadata, condition):
return False
elif key == "$or":
# At least one condition must match
if not isinstance(expected_value, list):
continue
any_match = False
for condition in expected_value:
if self._matches_filter(metadata, condition):
any_match = True
break
if not any_match:
return False
continue
# Get the actual value from metadata
actual_value = metadata.get(key)
# Handle different types of expected values
if isinstance(expected_value, dict):
# Handle comparison operators
for op, op_value in expected_value.items():
if op == "$eq":
if actual_value != op_value:
return False
elif op == "$ne":
if actual_value == op_value:
return False
elif op == "$in":
if (
not isinstance(op_value, list)
or actual_value not in op_value
):
return False
elif op == "$nin":
if isinstance(op_value, list) and actual_value in op_value:
return False
elif op == "$exists":
if bool(op_value) != (key in metadata):
return False
# Add more operators as needed
else:
# Simple equality check
if actual_value != expected_value:
return False
return True
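A few illustrative checks against the client-side matcher above; _matches_filter does not touch AWS, so an unconfigured client is fine here:

client = S3VectorClient()
metadata = {"file_id": "f1", "source": "doc.pdf", "page": 3}

assert client._matches_filter(metadata, {"file_id": "f1"})
assert client._matches_filter(metadata, {"page": {"$in": [1, 2, 3]}})
assert client._matches_filter(metadata, {"$or": [{"file_id": "other"}, {"source": "doc.pdf"}]})
assert not client._matches_filter(metadata, {"file_id": {"$ne": "f1"}})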

View file

@ -1,340 +0,0 @@
import weaviate
import re
import uuid
from typing import Any, Dict, List, Optional, Union
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.config import (
WEAVIATE_HTTP_HOST,
WEAVIATE_HTTP_PORT,
WEAVIATE_GRPC_PORT,
WEAVIATE_API_KEY,
)
def _convert_uuids_to_strings(obj: Any) -> Any:
"""
Recursively convert UUID objects to strings in nested data structures.
This function handles:
- UUID objects -> string
- Dictionaries with UUID values
- Lists/Tuples with UUID values
- Nested combinations of the above
Args:
obj: Any object that might contain UUIDs
Returns:
The same object structure with UUIDs converted to strings
"""
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, dict):
return {key: _convert_uuids_to_strings(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_convert_uuids_to_strings(item) for item in obj)
elif isinstance(obj, (str, int, float, bool, type(None))):
return obj
else:
return obj
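Example of the conversion above: UUIDs nested anywhere in a payload become plain strings, keeping Weaviate property values JSON-serializable:

import uuid

payload = {
    "file_id": uuid.UUID("12345678-1234-5678-1234-567812345678"),
    "tags": [uuid.uuid4(), "plain-string"],
    "nested": {"created_by": uuid.uuid4()},
}
clean = _convert_uuids_to_strings(payload)
assert isinstance(clean["file_id"], str)
assert all(isinstance(tag, str) for tag in clean["tags"])
assert isinstance(clean["nested"]["created_by"], str)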
class WeaviateClient(VectorDBBase):
def __init__(self):
self.url = WEAVIATE_HTTP_HOST
try:
# Build connection parameters
connection_params = {
"host": WEAVIATE_HTTP_HOST,
"port": WEAVIATE_HTTP_PORT,
"grpc_port": WEAVIATE_GRPC_PORT,
}
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
if WEAVIATE_API_KEY:
connection_params["auth_credentials"] = (
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
)
self.client = weaviate.connect_to_local(**connection_params)
self.client.connect()
except Exception as e:
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
def _sanitize_collection_name(self, collection_name: str) -> str:
"""Sanitize collection name to be a valid Weaviate class name."""
if not isinstance(collection_name, str) or not collection_name.strip():
raise ValueError("Collection name must be a non-empty string")
# Requirements for a valid Weaviate class name:
# The collection name must begin with a capital letter.
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
# Replace hyphens with underscores and keep only alphanumeric characters
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
name = name.strip("_")
if not name:
raise ValueError(
"Could not sanitize collection name to be a valid Weaviate class name"
)
# Ensure it starts with a letter and is capitalized
if not name[0].isalpha():
name = "C" + name
return name[0].upper() + name[1:]
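Illustrative sanitization results; since the method never uses self, it is called on the class here purely for demonstration, which avoids opening a Weaviate connection:

sanitize = WeaviateClient._sanitize_collection_name

print(sanitize(None, "open-webui_file-9b2c"))  # Open_webui_file_9b2c
print(sanitize(None, "123-chat-memory"))       # C123_chat_memory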
def has_collection(self, collection_name: str) -> bool:
sane_collection_name = self._sanitize_collection_name(collection_name)
return self.client.collections.exists(sane_collection_name)
def delete_collection(self, collection_name: str) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if self.client.collections.exists(sane_collection_name):
self.client.collections.delete(sane_collection_name)
def _create_collection(self, collection_name: str) -> None:
self.client.collections.create(
name=collection_name,
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
properties=[
weaviate.classes.config.Property(
name="text", data_type=weaviate.classes.config.DataType.TEXT
),
],
)
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
self._create_collection(sane_collection_name)
collection = self.client.collections.get(sane_collection_name)
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
self._create_collection(sane_collection_name)
collection = self.client.collections.get(sane_collection_name)
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(item["id"]) if item["id"] else None
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
result_ids, result_documents, result_metadatas, result_distances = (
[],
[],
[],
[],
)
for vector_embedding in vectors:
try:
response = collection.query.near_vector(
near_vector=vector_embedding,
limit=limit,
return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
metadatas = []
distances = []
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
# Weaviate returns cosine distance in [0, 2] (0 = best, 2 = worst); rescale to [0, 1] where 1 is best
raw_distances = [
(
obj.metadata.distance
if obj.metadata and obj.metadata.distance is not None
else 2.0
)
for obj in response.objects
]
distances = [(2 - dist) / 2 for dist in raw_distances]
result_ids.append(ids)
result_documents.append(documents)
result_metadatas.append(metadatas)
result_distances.append(distances)
except Exception:
result_ids.append([])
result_documents.append([])
result_metadatas.append([])
result_distances.append([])
return SearchResult(
**{
"ids": result_ids,
"documents": result_documents,
"metadatas": result_metadatas,
"distances": result_distances,
}
)
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
weaviate_filter = None
if filter:
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
value
)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
)
try:
response = collection.query.fetch_objects(
filters=weaviate_filter, limit=limit
)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
metadatas = []
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
except Exception:
return None
def get(self, collection_name: str) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
ids, documents, metadatas = [], [], []
try:
for item in collection.iterator():
ids.append(str(item.uuid))
properties = dict(item.properties) if item.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
if not ids:
return None
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
except Exception:
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return
collection = self.client.collections.get(sane_collection_name)
try:
if ids:
for item_id in ids:
collection.data.delete_by_id(uuid=item_id)
elif filter:
weaviate_filter = None
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(
name=key
).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
)
if weaviate_filter:
collection.data.delete_many(where=weaviate_filter)
except Exception:
pass
def reset(self) -> None:
try:
for collection_name in self.client.collections.list_all().keys():
self.client.collections.delete(collection_name)
except Exception:
pass

View file

@ -1,10 +1,6 @@
from open_webui.retrieval.vector.main import VectorDBBase
from open_webui.retrieval.vector.type import VectorType
from open_webui.config import (                                             (main)
    VECTOR_DB,                                                              (main)
    ENABLE_QDRANT_MULTITENANCY_MODE,                                        (main)
    ENABLE_MILVUS_MULTITENANCY_MODE,                                        (main)
)                                                                           (main)
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE    (v0.6.12)


class Vector:
@ -16,16 +12,9 @@ class Vector:
""" """
        match vector_type:
            case VectorType.MILVUS:

main:
                if ENABLE_MILVUS_MULTITENANCY_MODE:
                    from open_webui.retrieval.vector.dbs.milvus_multitenancy import (
                        MilvusClient,
                    )
                    return MilvusClient()
                else:
                    from open_webui.retrieval.vector.dbs.milvus import MilvusClient
                    return MilvusClient()

v0.6.12:
                from open_webui.retrieval.vector.dbs.milvus import MilvusClient
                return MilvusClient()

            case VectorType.QDRANT:
                if ENABLE_QDRANT_MULTITENANCY_MODE:
                    from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
@ -41,10 +30,6 @@ class Vector:
                from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
                return PineconeClient()

            case VectorType.S3VECTOR:                                                 (main only)
                from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient   (main only)
                return S3VectorClient()                                               (main only)

            case VectorType.OPENSEARCH:
                from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
@ -63,14 +48,6 @@ class Vector:
                from open_webui.retrieval.vector.dbs.chroma import ChromaClient
                return ChromaClient()

            case VectorType.ORACLE23AI:                                                  (main only)
                from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient  (main only)
                return Oracle23aiClient()                                                (main only)

            case VectorType.WEAVIATE:                                                    (main only)
                from open_webui.retrieval.vector.dbs.weaviate import WeaviateClient      (main only)
                return WeaviateClient()                                                  (main only)

            case _:
                raise ValueError(f"Unsupported vector type: {vector_type}")
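The enclosing factory method falls outside this hunk; assuming it is the classmethod accessor wrapping the match statement (named get_vector here, which is an assumption), a hypothetical call looks like:

from open_webui.retrieval.vector.type import VectorType

# get_vector is an assumed name for the factory classmethod shown above.
VECTOR_DB_CLIENT = Vector.get_vector(VectorType.QDRANT)
print(type(VECTOR_DB_CLIENT).__name__)  # QdrantClient (or the multitenancy variant)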

View file

@ -9,6 +9,3 @@ class VectorType(StrEnum):
    ELASTICSEARCH = "elasticsearch"
    OPENSEARCH = "opensearch"
    PGVECTOR = "pgvector"
    ORACLE23AI = "oracle23ai"    (main only)
    S3VECTOR = "s3vector"        (main only)
    WEAVIATE = "weaviate"        (main only)

View file

@ -1,28 +0,0 @@
from datetime import datetime
KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"]
def filter_metadata(metadata: dict[str, any]) -> dict[str, any]:
metadata = {
key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE
}
return metadata
def process_metadata(
metadata: dict[str, any],
) -> dict[str, any]:
for key, value in list(metadata.items()):
# Remove large fields (iterate over a copy so keys can be deleted safely)
if key in KEYS_TO_EXCLUDE:
del metadata[key]
continue
# Convert non-serializable fields to strings
if (
isinstance(value, datetime)
or isinstance(value, list)
or isinstance(value, dict)
):
metadata[key] = str(value)
return metadata
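A small example of the behaviour above: excluded extraction fields are dropped, and values that are not JSON-serializable are stringified before the metadata reaches a vector backend:

from datetime import datetime

meta = {
    "source": "doc.pdf",
    "pages": ["p1", "p2"],               # excluded key, removed
    "created_at": datetime(2024, 1, 1),  # datetime, converted to str
    "tags": ["a", "b"],                  # list, converted to str
}
print(process_metadata(meta))
# {'source': 'doc.pdf', 'created_at': '2024-01-01 00:00:00', 'tags': "['a', 'b']"}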

View file

@ -1,126 +0,0 @@
import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
log = logging.getLogger(__name__)
"""
Azure AI Search integration for Open WebUI.
Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python
Required package: azure-search-documents
Install: pip install azure-search-documents
"""
def search_azure(
api_key: str,
endpoint: str,
index_name: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""
Search using Azure AI Search.
Args:
api_key: Azure Search API key (query key or admin key)
endpoint: Azure Search service endpoint (e.g., https://myservice.search.windows.net)
index_name: Name of the search index to query
query: Search query string
count: Number of results to return
filter_list: Optional list of domains to filter results
Returns:
List of SearchResult objects with link, title, and snippet
"""
try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError:
log.error(
"azure-search-documents package is not installed. "
"Install it with: pip install azure-search-documents"
)
raise ImportError(
"azure-search-documents is required for Azure AI Search. "
"Install it with: pip install azure-search-documents"
)
try:
# Create search client with API key authentication
credential = AzureKeyCredential(api_key)
search_client = SearchClient(
endpoint=endpoint, index_name=index_name, credential=credential
)
# Perform the search
results = search_client.search(search_text=query, top=count)
# Convert results to list and extract fields
search_results = []
for result in results:
# Azure AI Search returns documents with custom schemas
# We need to extract common fields that might represent URL, title, and content
# Common field names to look for:
result_dict = dict(result)
# Try to find URL field (common names)
link = (
result_dict.get("url")
or result_dict.get("link")
or result_dict.get("uri")
or result_dict.get("metadata_storage_path")
or ""
)
# Try to find title field (common names)
title = (
result_dict.get("title")
or result_dict.get("name")
or result_dict.get("metadata_title")
or result_dict.get("metadata_storage_name")
or None
)
# Try to find content/snippet field (common names)
snippet = (
result_dict.get("content")
or result_dict.get("snippet")
or result_dict.get("description")
or result_dict.get("summary")
or result_dict.get("text")
or None
)
# Truncate snippet if too long
if snippet and len(snippet) > 500:
snippet = snippet[:497] + "..."
if link: # Only add if we found a valid link
search_results.append(
{
"link": link,
"title": title,
"snippet": snippet,
}
)
# Apply domain filtering if specified
if filter_list:
search_results = get_filtered_results(search_results, filter_list)
# Convert to SearchResult objects
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in search_results
]
except Exception as ex:
log.error(f"Azure AI Search error: {ex}")
raise ex
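A hypothetical call to the function above; the endpoint, index name, and key are placeholders, and the index schema determines which fields actually map to link/title/snippet:

results = search_azure(
    api_key="<query-key>",
    endpoint="https://myservice.search.windows.net",
    index_name="docs-index",
    query="open webui rag pipeline",
    count=5,
    filter_list=["learn.microsoft.com"],
)
for result in results:
    print(result.link, result.title)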

View file

@ -4,9 +4,11 @@ from pprint import pprint
from typing import Optional

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS    (v0.6.12 only)
import argparse

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])          (v0.6.12 only)

"""
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
"""

View file

@ -4,18 +4,20 @@ from typing import Optional
import requests
import json

from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS    (v0.6.12 only)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])          (v0.6.12 only)


def _parse_response(response):
    results = []    (main)
    result = {}     (v0.6.12)
    if "data" in response:
        data = response["data"]
        if "webPages" in data:
            webPages = data["webPages"]
            if "value" in webPages:
                results = [              (main)
                result["webpage"] = [    (v0.6.12)
                    {
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
@ -29,7 +31,7 @@ def _parse_response(response):
                    }
                    for item in webPages["value"]
                ]
    return results    (main)
    return result     (v0.6.12)


def search_bocha(
@ -51,7 +53,7 @@ def search_bocha(
    response = requests.post(url, headers=headers, data=payload, timeout=5)
    response.raise_for_status()
    results = _parse_response(response.json())
    print(results)    (v0.6.12 only)

    if filter_list:
        results = get_filtered_results(results, filter_list)
@ -59,5 +61,5 @@ def search_bocha(
        SearchResult(
            link=result["url"], title=result.get("name"), snippet=result.get("summary")
        )
        for result in results[:count]                        (main)
        for result in results.get("webpage", [])[:count]     (v0.6.12)
    ]

View file

@ -3,8 +3,10 @@ from typing import Optional
import requests

from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS    (v0.6.12 only)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])          (v0.6.12 only)


def search_brave(
@ -34,9 +36,7 @@ def search_brave(
return [ return [
SearchResult( SearchResult(
link=result["url"], link=result["url"], title=result.get("title"), snippet=result.get("snippet")
title=result.get("title"),
snippet=result.get("description"),
) )
for result in results[:count] for result in results[:count]
] ]
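
Besides reformatting the constructor call, the main side reads the snippet from result.get("description") rather than result.get("snippet"). A hedged sketch of the mapping, assuming a Brave-style web result item carrying "title", "url", and "description" (the sample item is hypothetical, and the dataclass is a simplified stand-in for open_webui.retrieval.web.main.SearchResult):

from dataclasses import dataclass
from typing import Optional


@dataclass
class SearchResult:  # simplified stand-in, not the real Open WebUI model
    link: str
    title: Optional[str] = None
    snippet: Optional[str] = None


item = {"url": "https://example.com", "title": "Example", "description": "Short teaser text."}

mapped = SearchResult(link=item["url"], title=item.get("title"), snippet=item.get("description"))
print(mapped)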


@@ -2,17 +2,16 @@ import logging
 from typing import Optional
 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
-from ddgs import DDGS
-from ddgs.exceptions import RatelimitException
+from duckduckgo_search import DDGS
+from duckduckgo_search.exceptions import RatelimitException
+from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 def search_duckduckgo(
-    query: str,
-    count: int,
-    filter_list: Optional[list[str]] = None,
-    concurrent_requests: Optional[int] = None,
+    query: str, count: int, filter_list: Optional[list[str]] = None
 ) -> list[SearchResult]:
     """
     Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
@@ -26,9 +25,6 @@ def search_duckduckgo(
     # Use the DDGS context manager to create a DDGS object
     search_results = []
     with DDGS() as ddgs:
-        if concurrent_requests:
-            ddgs.threads = concurrent_requests
         # Use the ddgs.text() method to perform the search
         try:
             search_results = ddgs.text(
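
On the main side this module has moved from the duckduckgo_search package to its successor, ddgs, and gained a concurrent_requests parameter that is applied as ddgs.threads. A minimal sketch of that main-side usage; the query, count, and keyword arguments to text() are illustrative assumptions:

# Requires the `ddgs` package (the successor to `duckduckgo_search`).
from ddgs import DDGS
from ddgs.exceptions import RatelimitException

query = "open webui"       # illustrative
count = 5                  # illustrative
concurrent_requests = 4    # illustrative

results = []
with DDGS() as ddgs:
    if concurrent_requests:
        ddgs.threads = concurrent_requests  # mirrors the tuning added on the main side
    try:
        results = ddgs.text(query, max_results=count)
    except RatelimitException:
        results = []

for r in results:
    print(r)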


@@ -3,9 +3,11 @@ from dataclasses import dataclass
 from typing import Optional
 import requests
+from open_webui.env import SRC_LOG_LEVELS
 from open_webui.retrieval.web.main import SearchResult
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 EXA_API_BASE = "https://api.exa.ai"


@@ -2,40 +2,27 @@ import logging
 from typing import Optional, List
 import requests
-from fastapi import Request
 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
-from open_webui.utils.headers import include_user_info_headers
+from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 def search_external(
-    request: Request,
     external_url: str,
     external_api_key: str,
     query: str,
     count: int,
     filter_list: Optional[List[str]] = None,
-    user=None,
 ) -> List[SearchResult]:
     try:
-        headers = {
-            "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
-            "Authorization": f"Bearer {external_api_key}",
-        }
-        headers = include_user_info_headers(headers, user)
-        chat_id = getattr(request.state, "chat_id", None)
-        if chat_id:
-            headers["X-OpenWebUI-Chat-Id"] = str(chat_id)
         response = requests.post(
             external_url,
-            headers=headers,
+            headers={
+                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
+                "Authorization": f"Bearer {external_api_key}",
+            },
             json={
                 "query": query,
                 "count": count,


@@ -1,10 +1,13 @@
 import logging
 from typing import Optional, List
+from urllib.parse import urljoin
+import requests
 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 def search_firecrawl(
@@ -15,20 +18,27 @@ def search_firecrawl(
     filter_list: Optional[List[str]] = None,
 ) -> List[SearchResult]:
     try:
-        from firecrawl import FirecrawlApp
-        firecrawl = FirecrawlApp(api_key=firecrawl_api_key, api_url=firecrawl_url)
-        response = firecrawl.search(
-            query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3
+        firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
+        response = requests.post(
+            firecrawl_search_url,
+            headers={
+                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
+                "Authorization": f"Bearer {firecrawl_api_key}",
+            },
+            json={
+                "query": query,
+                "limit": count,
+            },
         )
-        results = response.web
+        response.raise_for_status()
+        results = response.json().get("data", [])
         if filter_list:
             results = get_filtered_results(results, filter_list)
         results = [
             SearchResult(
-                link=result.url,
-                title=result.title,
-                snippet=result.description,
+                link=result.get("url"),
+                title=result.get("title"),
+                snippet=result.get("description"),
             )
             for result in results[:count]
         ]
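
The main side calls the firecrawl SDK (FirecrawlApp(...).search(...)) and reads results as objects with url/title/description attributes from response.web, whereas v0.6.12 POSTs to /v1/search and gets plain dicts back, which is why the SearchResult mapping switches between attribute access and .get(...). A small adapter sketch that normalizes either shape into one dict; the Hit dataclass is a hypothetical stand-in for the SDK's result objects:

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Hit:  # hypothetical stand-in for a firecrawl SDK result object
    url: str
    title: Optional[str] = None
    description: Optional[str] = None


def normalize(result: Union[Hit, dict]) -> dict:
    """Map an SDK object (attribute access) or a raw /v1/search dict to one shape."""
    if isinstance(result, dict):  # v0.6.12-style: item from response.json()["data"]
        return {
            "link": result.get("url"),
            "title": result.get("title"),
            "snippet": result.get("description"),
        }
    return {"link": result.url, "title": result.title, "snippet": result.description}


print(normalize(Hit(url="https://example.com", title="Doc", description="teaser")))
print(normalize({"url": "https://example.com", "title": "Doc", "description": "teaser"}))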


@@ -3,8 +3,10 @@ from typing import Optional
 import requests
 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 def search_google_pse(
@@ -13,7 +15,6 @@ def search_google_pse(
     query: str,
     count: int,
     filter_list: Optional[list[str]] = None,
-    referer: Optional[str] = None,
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
     Handles pagination for counts greater than 10.
@@ -29,11 +30,7 @@ def search_google_pse(
         list[SearchResult]: A list of SearchResult objects.
     """
     url = "https://www.googleapis.com/customsearch/v1"
     headers = {"Content-Type": "application/json"}
-    if referer:
-        headers["Referer"] = referer
     all_results = []
     start_index = 1  # Google PSE start parameter is 1-based
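
The main side adds an optional referer argument and, when set, sends it as a Referer header (useful when the Google API key is restricted by HTTP referrer); both sides then page through results with the 1-based start parameter. A hedged sketch of that request loop against the Custom Search JSON API; the key, engine id, query, and referer values are placeholders, and the parameter handling is a simplification of the function above:

import requests

api_key, engine_id, query = "<key>", "<cx>", "open webui"   # placeholders
referer = "https://my-deployment.example"                    # placeholder
count = 15  # more than 10, so at least two pages are needed

url = "https://www.googleapis.com/customsearch/v1"
headers = {"Content-Type": "application/json"}
if referer:
    headers["Referer"] = referer  # mirrors the main-side change above

all_results, start_index = [], 1  # Google PSE 'start' parameter is 1-based
while len(all_results) < count:
    num = min(10, count - len(all_results))  # the API returns at most 10 items per page
    params = {"key": api_key, "cx": engine_id, "q": query, "num": num, "start": start_index}
    response = requests.get(url, headers=headers, params=params, timeout=10)
    response.raise_for_status()
    items = response.json().get("items", [])
    if not items:
        break
    all_results.extend(items)
    start_index += num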


@@ -2,9 +2,11 @@ import logging
 import requests
 from open_webui.retrieval.web.main import SearchResult
+from open_webui.env import SRC_LOG_LEVELS
 from yarl import URL
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:


@@ -3,8 +3,10 @@ from typing import Optional
 import requests
 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 def search_kagi(

Some files were not shown because too many files have changed in this diff.