Compare commits


No commits in common. "main" and "v0.6.12" have entirely different histories.

733 changed files with 42501 additions and 123761 deletions


@ -7,15 +7,6 @@ OPENAI_API_KEY=''
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
# For production, you should only need one host as
# fastapi serves the svelte-kit built frontend and backend from the same host and port.
# To test with CORS locally, you can set something like
# CORS_ALLOW_ORIGIN='http://localhost:5173;http://localhost:8080'
CORS_ALLOW_ORIGIN='*'
# For production you should set this to match the proxy configuration (127.0.0.1)
FORWARDED_ALLOW_IPS='*'
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true
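For context, a minimal sketch of how these variables might be set for local cross-origin testing versus a proxied production deployment; the hostnames and ports are illustrative assumptions drawn from the comments above, not values defined in this diff.

```bash
# Local development: allow the dev server and backend origins mentioned in the comments above.
export CORS_ALLOW_ORIGIN='http://localhost:5173;http://localhost:8080'
export FORWARDED_ALLOW_IPS='*'

# Production behind a reverse proxy on the same host (assumed setup): trust only the proxy
# address rather than '*', and restrict CORS to your public origin (example hostname).
# export CORS_ALLOW_ORIGIN='https://chat.example.com'
# export FORWARDED_ALLOW_IPS='127.0.0.1'

# Opt out of telemetry, matching the defaults shown above.
export SCARF_NO_ANALYTICS=true
export DO_NOT_TRACK=true
```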

.gitattributes vendored (50 changed lines)

@ -1,49 +1 @@
# TypeScript
*.ts text eol=lf
*.tsx text eol=lf
# JavaScript
*.js text eol=lf
*.jsx text eol=lf
*.mjs text eol=lf
*.cjs text eol=lf
# Svelte
*.svelte text eol=lf
# HTML/CSS
*.html text eol=lf
*.css text eol=lf
*.scss text eol=lf
*.less text eol=lf
# Config files and JSON
*.json text eol=lf
*.jsonc text eol=lf
*.yml text eol=lf
*.yaml text eol=lf
*.toml text eol=lf
# Shell scripts
*.sh text eol=lf
# Markdown & docs
*.md text eol=lf
*.mdx text eol=lf
*.txt text eol=lf
# Git-related
.gitattributes text eol=lf
.gitignore text eol=lf
# Prettier and other dotfiles
.prettierrc text eol=lf
.prettierignore text eol=lf
.eslintrc text eol=lf
.eslintignore text eol=lf
.stylelintrc text eol=lf
.editorconfig text eol=lf
# Misc
*.env text eol=lf
*.lock text eol=lf
*.sh text eol=lf
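A `.gitattributes` EOL policy like the one shown here does not retroactively rewrite files already in the index; a common follow-up step, noted only as a hedged aside (it is not part of this diff), is to renormalize the working tree:

```bash
# Re-apply the eol=lf attributes to tracked files and commit the resulting line-ending changes.
git add --renormalize .
git commit -m "Normalize line endings to LF"
```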


@ -11,9 +11,7 @@ body:
## Important Notes
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) and [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project. Duplicates may be closed without notice. **Please search for existing issues AND discussions, whether open or closed.**
- Check for opened, **but also for (recently) CLOSED issues** as the issue you are trying to report **might already have been fixed on the dev branch!**
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.
- **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.
@ -21,19 +19,13 @@ body:
- **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
- **Scope**: If you want to report a SECURITY VULNERABILITY, then do so through our [GitHub security page](https://github.com/open-webui/open-webui/security).
- type: checkboxes
id: issue-check
attributes:
label: Check Existing Issues
description: Confirm that you've checked for existing reports before submitting a new one.
options:
- label: I have searched for any existing and/or related issues.
required: true
- label: I have searched for any existing and/or related discussions.
required: true
- label: I have also searched in the CLOSED issues AND CLOSED discussions and found no related items (your issue might already be addressed on the development branch!).
- label: I have searched the existing issues and discussions.
required: true
- label: I am using the latest version of Open WebUI.
required: true
@ -55,7 +47,7 @@ body:
id: open-webui-version
attributes:
label: Open WebUI Version
description: Specify the version (e.g., v0.6.26)
description: Specify the version (e.g., v0.3.11)
validations:
required: true
@ -71,7 +63,7 @@ body:
id: operating-system
attributes:
label: Operating System
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04, Debian 12)
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
validations:
required: true
@ -134,7 +126,6 @@ body:
description: |
Please provide a **very detailed, step-by-step guide** to reproduce the issue. Your instructions should be so clear and precise that anyone can follow them without guesswork. Include every relevant detail—settings, configuration options, exact commands used, values entered, and any prerequisites or environment variables.
**If full reproduction steps and all relevant settings are not provided, your issue may not be addressed.**
**If your steps to reproduce are incomplete, lacking detail, or not reproducible, your issue cannot be addressed.**
placeholder: |
Example (include every detail):
@ -172,5 +163,5 @@ body:
attributes:
value: |
## Note
**If the bug report is incomplete, does not follow instructions or is lacking details it may not be addressed.** Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
Thank you for contributing to Open WebUI!


@ -8,21 +8,10 @@ body:
value: |
## Important Notes
### Before submitting
Please check the **open AND closed** [Issues](https://github.com/open-webui/open-webui/issues) AND [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
It's likely we're already tracking it! If you're unsure, start a discussion post first.
#### Scope
If your feature request is likely to take more than a quick coding session to implement, test and verify, then open it in the **Ideas** section of the [Discussions](https://github.com/open-webui/open-webui/discussions) instead.
**We will close and force move your feature request to the Ideas section, if we believe your feature request is not trivial/quick to implement.**
This is to ensure the issues tab is used only for issues, quickly addressable feature requests and tracking tickets by the maintainers.
Other feature requests belong in the **Ideas** section of the [Discussions](https://github.com/open-webui/open-webui/discussions).
If your feature request might impact others in the community, definitely open a discussion instead and evaluate whether and how to implement it.
This will help us efficiently focus on improving the project.
### Collaborate respectfully
We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.
@ -33,6 +22,7 @@ body:
We appreciate your time and ask that you **respect ours**.
### Contributing
If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
@ -45,22 +35,14 @@ body:
label: Check Existing Issues
description: Please confirm that you've checked for existing similar requests
options:
- label: I have searched for all existing **open AND closed** issues and discussions for similar requests. I have found none that is comparable to my request.
required: true
- type: checkboxes
id: feature-scope
attributes:
label: Verify Feature Scope
description: Please confirm the feature's scope is within the described scope
options:
- label: I have read through and understood the scope definition for feature requests in the Issues section. I believe my feature request meets the definition and belongs in the Issues section instead of the Discussions.
- label: I have searched the existing issues and discussions.
required: true
- type: textarea
id: problem-description
attributes:
label: Problem Description
description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
placeholder: "Ex. I'm always frustrated when... / Not related to a problem"
placeholder: "Ex. I'm always frustrated when..."
validations:
required: true
- type: textarea


@ -12,6 +12,12 @@ updates:
interval: monthly
target-branch: 'dev'
- package-ecosystem: npm
directory: '/'
schedule:
interval: monthly
target-branch: 'dev'
- package-ecosystem: 'github-actions'
directory: '/'
schedule:


@ -1,20 +1,17 @@
# Pull Request Checklist
### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) to discuss your idea/fix with the community before creating a pull request, and describe your changes before submitting a pull request.
This is to ensure large feature PRs are discussed with the community first, before starting work on it. If the community does not want this feature or it is not relevant for Open WebUI as a project, it can be identified in the discussion before working on the feature and submitting the PR.
### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) and describe your changes before submitting a pull request.
**Before submitting, make sure you've checked the following:**
- [ ] **Target branch:** Verify that the pull request targets the `dev` branch. **Not targeting the `dev` branch will lead to immediate closure of the PR.**
- [ ] **Description:** Provide a concise description of the changes made in this pull request down below.
- [ ] **Target branch:** Please verify that the pull request targets the `dev` branch.
- [ ] **Description:** Provide a concise description of the changes made in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** If necessary, update relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs) like environment variables, the tutorials, or other documentation sources.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Perform manual tests to **verify the implemented fix/feature works as intended AND does not break any other functionality**. Take this as an opportunity to **make screenshots of the feature/fix and include it in the PR description**.
- [ ] **Agentic AI Code:** Confirm this Pull Request is **not written by any AI Agent** or has at least **gone through additional human review AND manual testing**. If any AI Agent is the co-author of this PR, it may lead to immediate closure of the PR.
- [ ] **Testing:** Have you written and run sufficient tests to validate the changes?
- [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
- [ ] **Title Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
- [ ] **Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
- **BREAKING CHANGE**: Significant changes that may affect compatibility
- **build**: Changes that affect the build system or external dependencies
- **ci**: Changes to our continuous integration processes or workflows
@ -76,7 +73,4 @@ This is to ensure large feature PRs are discussed with the community first, befo
### Contributor License Agreement
By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](https://github.com/open-webui/open-webui/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.
> [!NOTE]
> Deleting the CLA section will lead to immediate closure of your PR and it will not be merged in.
By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.


@ -11,7 +11,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Check for changes in package.json
run: |
@ -36,7 +36,7 @@ jobs:
echo "::set-output name=content::$CHANGELOG_ESCAPED"
- name: Create GitHub release
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@ -61,7 +61,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Trigger Docker build workflow
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
github.rest.actions.createWorkflowDispatch({


@ -27,7 +27,7 @@ jobs:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
lfs: true
@ -57,8 +57,7 @@ jobs:
git lfs install
git lfs track "*.ttf"
git lfs track "*.jpg"
rm demo.png
rm banner.png
rm demo.gif
git add .
git commit -m "GitHub deploy: ${{ github.sha }}"
git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main


@ -14,18 +14,16 @@ env:
jobs:
build-main-image:
runs-on: ${{ matrix.runner }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
platform:
- linux/amd64
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@ -43,7 +41,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@ -113,18 +111,16 @@ jobs:
retention-days: 1
build-cuda-image:
runs-on: ${{ matrix.runner }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
platform:
- linux/amd64
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@ -141,11 +137,8 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@ -217,125 +210,17 @@ jobs:
if-no-files-found: error
retention-days: 1
build-cuda126-image:
runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache
- name: Checkout repository
uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (cuda126 tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-cuda126,onlatest=true
- name: Extract metadata for Docker cache
id: cache-meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-cuda126-${{ matrix.platform }}-
latest=false
- name: Build Docker image (cuda126)
uses: docker/build-push-action@v5
id: build
with:
context: .
push: true
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: |
BUILD_HASH=${{ github.sha }}
USE_CUDA=true
USE_CUDA_VER=cu126
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-cuda126-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
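Given the tag and flavor rules above (`value=cuda126` on `main`, `suffix=-cuda126,onlatest=true`), the published multi-arch image would be consumed roughly as follows; the exact tag names depend on the ref that was built and are assumptions here:

```bash
# Pull the cuda126 variant from GHCR (assumed tag name derived from the metadata rules above).
docker pull ghcr.io/open-webui/open-webui:cuda126

# Inspect the merged manifest list to confirm linux/amd64 and linux/arm64 entries are present.
docker buildx imagetools inspect ghcr.io/open-webui/open-webui:cuda126
```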
build-ollama-image:
runs-on: ${{ matrix.runner }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
platform:
- linux/amd64
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
@ -353,7 +238,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@ -425,108 +310,6 @@ jobs:
if-no-files-found: error
retention-days: 1
build-slim-image:
runs-on: ${{ matrix.runner }}
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
runner: ubuntu-latest
- platform: linux/arm64
runner: ubuntu-24.04-arm
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout repository
uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (slim tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=slim
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-slim,onlatest=true
- name: Extract metadata for Docker cache
id: cache-meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
flavor: |
prefix=cache-slim-${{ matrix.platform }}-
latest=false
- name: Build Docker image (slim)
uses: docker/build-push-action@v5
id: build
with:
context: .
push: true
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: |
BUILD_HASH=${{ github.sha }}
USE_SLIM=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-slim-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
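The slim job above differs from the main build mainly in passing `USE_SLIM=true`; a rough single-platform local equivalent (the image tag is an example) would be:

```bash
# Build the slim variant locally; per the USE_SLIM branch in the Dockerfile diff further down,
# this skips pre-downloading the embedding, whisper, and tiktoken assets at build time.
docker build \
  --build-arg USE_SLIM=true \
  --build-arg BUILD_HASH="$(git rev-parse HEAD)" \
  -t open-webui:slim-local .
```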
merge-main-images:
runs-on: ubuntu-latest
needs: [build-main-image]
@ -541,7 +324,7 @@ jobs:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
uses: actions/download-artifact@v4
with:
pattern: digests-main-*
path: /tmp/digests
@ -595,7 +378,7 @@ jobs:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
uses: actions/download-artifact@v4
with:
pattern: digests-cuda-*
path: /tmp/digests
@ -637,62 +420,6 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-cuda126-images:
runs-on: ubuntu-latest
needs: [build-cuda126-image]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
with:
pattern: digests-cuda126-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (default latest tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-cuda126,onlatest=true
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-ollama-images:
runs-on: ubuntu-latest
needs: [build-ollama-image]
@ -707,7 +434,7 @@ jobs:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
uses: actions/download-artifact@v4
with:
pattern: digests-ollama-*
path: /tmp/digests
@ -748,59 +475,3 @@ jobs:
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
merge-slim-images:
runs-on: ubuntu-latest
needs: [build-slim-image]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v5
with:
pattern: digests-slim-*
path: /tmp/digests
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata for Docker images (default slim tag)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.FULL_IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=slim
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
suffix=-slim,onlatest=true
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}


@ -30,10 +30,10 @@ jobs:
- 3.12.x
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: '${{ matrix.python-version }}'


@ -24,15 +24,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v5
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Install Dependencies
run: npm install --force
run: npm install
- name: Format Frontend
run: npm run format
@ -51,15 +51,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v5
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Install Dependencies
run: npm ci --force
run: npm ci
- name: Run vitest
run: npm run test:frontend
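Both jobs can be reproduced locally with the same commands they run (Node 22 assumed, matching the `node-version` input; the `--force` flags appear only on one side of this diff):

```bash
# Formatting job equivalent.
npm install --force
npm run format

# Unit-test job equivalent.
npm ci --force
npm run test:frontend
```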


@ -16,15 +16,15 @@ jobs:
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install Git
run: sudo apt-get update && sudo apt-get install -y git
- uses: actions/setup-node@v5
- uses: actions/setup-node@v4
with:
node-version: 22
- uses: actions/setup-python@v6
- uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Build

.gitignore vendored (5 changed lines)

@ -1,5 +1,3 @@
x.py
yarn.lock
.DS_Store
node_modules
/build
@ -14,8 +12,7 @@ vite.config.ts.timestamp-*
__pycache__/
*.py[cod]
*$py.class
.nvmrc
CLAUDE.md
# C extensions
*.so


@ -3,6 +3,8 @@ pnpm-lock.yaml
package-lock.json
yarn.lock
kubernetes/
# Copy of .gitignore
.DS_Store
node_modules


@ -5,6 +5,5 @@
"printWidth": 100,
"plugins": ["prettier-plugin-svelte"],
"pluginSearchDirs": ["."],
"overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }],
"endOfLine": "lf"
"overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
}

File diff suppressed because it is too large.


@ -3,8 +3,6 @@
# use build args in the docker build command with --build-arg="BUILDARG=true"
ARG USE_CUDA=false
ARG USE_OLLAMA=false
ARG USE_SLIM=false
ARG USE_PERMISSION_HARDENING=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
ARG USE_CUDA_VER=cu128
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
@ -26,16 +24,13 @@ ARG GID=0
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
ARG BUILD_HASH
# Set Node.js options (heap limit Allocation failed - JavaScript heap out of memory)
# ENV NODE_OPTIONS="--max-old-space-size=4096"
WORKDIR /app
# to store git revision in build
RUN apk add --no-cache git
COPY package.json package-lock.json ./
RUN npm ci --force
RUN npm ci
COPY . .
ENV APP_BUILD_HASH=${BUILD_HASH}
@ -48,23 +43,17 @@ FROM python:3.11-slim-bookworm AS base
ARG USE_CUDA
ARG USE_OLLAMA
ARG USE_CUDA_VER
ARG USE_SLIM
ARG USE_PERMISSION_HARDENING
ARG USE_EMBEDDING_MODEL
ARG USE_RERANKING_MODEL
ARG UID
ARG GID
# Python settings
ENV PYTHONUNBUFFERED=1
## Basis ##
ENV ENV=prod \
PORT=8080 \
# pass build args to the build
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
USE_CUDA_DOCKER=${USE_CUDA} \
USE_SLIM_DOCKER=${USE_SLIM} \
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
@ -119,13 +108,29 @@ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry
# Make sure the user has access to the app and root directory
RUN chown -R $UID:$GID /app $HOME
# Install common system dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
git build-essential pandoc gcc netcat-openbsd curl jq \
python3-dev \
ffmpeg libsm6 libxext6 \
&& rm -rf /var/lib/apt/lists/*
RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# install helper tools
apt-get install -y --no-install-recommends curl jq && \
# install ollama
curl -fsSL https://ollama.com/install.sh | sh && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
else \
apt-get update && \
# Install pandoc, netcat and gcc
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
fi
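As the build-arg comment at the top of the Dockerfile indicates, these branches are selected at build time; for example (image tags are illustrative):

```bash
# Bundle Ollama into the image (runs the ollama.com install script during the build).
docker build --build-arg USE_OLLAMA=true -t open-webui:ollama-local .

# CUDA-enabled build, selecting the torch wheel index via USE_CUDA_VER (cu128 appears as the ARG default above).
docker build --build-arg USE_CUDA=true --build-arg USE_CUDA_VER=cu128 -t open-webui:cuda-local .
```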
# install python dependencies
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
@ -141,22 +146,13 @@ RUN pip3 install --no-cache-dir uv && \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
if [ "$USE_SLIM" != "true" ]; then \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
fi; \
fi; \
mkdir -p /app/backend/data && chown -R $UID:$GID /app/backend/data/ && \
rm -rf /var/lib/apt/lists/*;
chown -R $UID:$GID /app/backend/data/
# Install Ollama if requested
RUN if [ "$USE_OLLAMA" = "true" ]; then \
date +%s > /tmp/ollama_build_hash && \
echo "Cache broken at timestamp: `cat /tmp/ollama_build_hash`" && \
curl -fsSL https://ollama.com/install.sh | sh && \
rm -rf /var/lib/apt/lists/*; \
fi
# copy embedding weight from build
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
@ -174,17 +170,6 @@ EXPOSE 8080
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
# Minimal, atomic permission hardening for OpenShift (arbitrary UID):
# - Group 0 owns /app and /root
# - Directories are group-writable and have SGID so new files inherit GID 0
RUN if [ "$USE_PERMISSION_HARDENING" = "true" ]; then \
set -eux; \
chgrp -R 0 /app /root || true; \
chmod -R g+rwX /app /root || true; \
find /app -type d -exec chmod g+s {} + || true; \
find /root -type d -exec chmod g+s {} + || true; \
fi
USER $UID:$GID
ARG BUILD_HASH

INSTALLATION.md (new file, 35 lines)

@ -0,0 +1,35 @@
### Installing Both Ollama and Open WebUI Using Kustomize
For cpu-only pod
```bash
kubectl apply -f ./kubernetes/manifest/base
```
For gpu-enabled pod
```bash
kubectl apply -k ./kubernetes/manifest
```
### Installing Both Ollama and Open WebUI Using Helm
Package Helm file first
```bash
helm package ./kubernetes/helm/
```
For cpu-only pod
```bash
helm install ollama-webui ./ollama-webui-*.tgz
```
For gpu-enabled pod
```bash
helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
```
Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization
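Beyond the `--set` override shown above, chart parameters can also be supplied from a file; the filename below is an example, and only `kubernetes/helm/values.yaml` comes from the text:

```bash
# Dump the chart's default values, edit them, then install with the overrides.
helm show values ./ollama-webui-*.tgz > my-values.yaml
helm install ollama-webui ./ollama-webui-*.tgz -f my-values.yaml
```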


@ -1,4 +1,4 @@
Copyright (c) 2023- Open WebUI Inc. [Created by Timothy Jaeryang Baek]
Copyright (c) 2023-2025 Timothy Jaeryang Baek (Open WebUI)
All rights reserved.
Redistribution and use in source and binary forms, with or without


@ -1,53 +0,0 @@
All code and materials created before commit `60d84a3aae9802339705826e9095e272e3c83623` are subject to the following copyright and license:
Copyright (c) 2023-2025 Timothy Jaeryang Baek
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
All code and materials created before commit `a76068d69cd59568b920dfab85dc573dbbb8f131` are subject to the following copyright and license:
MIT License
Copyright (c) 2023 Timothy Jaeryang Baek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -1,11 +0,0 @@
# Open WebUI Multi-License Notice
This repository contains code governed by multiple licenses based on the date and origin of contribution:
1. All code committed prior to commit a76068d69cd59568b920dfab85dc573dbbb8f131 is licensed under the MIT License (see LICENSE_HISTORY).
2. All code committed from commit a76068d69cd59568b920dfab85dc573dbbb8f131 up to and including commit 60d84a3aae9802339705826e9095e272e3c83623 is licensed under the BSD 3-Clause License (see LICENSE_HISTORY).
3. All code contributed or modified after commit 60d84a3aae9802339705826e9095e272e3c83623 is licensed under the Open WebUI License (see LICENSE).
For details on which commits are covered by which license, refer to LICENSE_HISTORY.


@ -10,16 +10,12 @@
[![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
[![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
![Open WebUI Banner](./banner.png)
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
Passionate about open-source AI? [Join our team →](https://careers.openwebui.com/)
![Open WebUI Demo](./demo.png)
![Open WebUI Demo](./demo.gif)
> [!TIP]
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** **[Speak with Our Sales Team Today!](https://docs.openwebui.com/enterprise)**
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
>
> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
@ -39,38 +35,24 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features using multiple Speech-to-Text providers (Local Whisper, OpenAI, Deepgram, Azure) and Text-to-Speech engines (Azure, ElevenLabs, OpenAI, Transformers, WebAPI), allowing for dynamic and interactive chat environments.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
- 💾 **Persistent Artifact Storage**: Built-in key-value storage API for artifacts, enabling features like journals, trackers, leaderboards, and collaborative tools with both personal and shared data scopes across sessions.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support using your choice of 9 vector databases and multiple content extraction engines (Tika, Docling, Document Intelligence, Mistral OCR, External loaders). Load documents directly into chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using 15+ providers including `SearXNG`, `Google PSE`, `Brave Search`, `Kagi`, `Mojeek`, `Tavily`, `Perplexity`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `SearchApi`, `SerpApi`, `Bing`, `Jina`, `Exa`, `Sougou`, `Azure AI Search`, and `Ollama Cloud`, injecting results directly into your chat experience.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
- 🎨 **Image Generation & Editing Integration**: Create and edit images using multiple engines including OpenAI's DALL-E, Gemini, ComfyUI (local), and AUTOMATIC1111 (local), with support for both generation and prompt-based editing workflows.
- 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
- 🗄️ **Flexible Database & Storage Options**: Choose from SQLite (with optional encryption), PostgreSQL, or configure cloud storage backends (S3, Google Cloud Storage, Azure Blob Storage) for scalable deployments.
- 🔍 **Advanced Vector Database Support**: Select from 9 vector database options including ChromaDB, PGVector, Qdrant, Milvus, Elasticsearch, OpenSearch, Pinecone, S3Vector, and Oracle 23ai for optimal RAG performance.
- 🔐 **Enterprise Authentication**: Full support for LDAP/Active Directory integration, SCIM 2.0 automated provisioning, and SSO via trusted headers alongside OAuth providers. Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
- ☁️ **Cloud-Native Integration**: Native support for Google Drive and OneDrive/SharePoint file picking, enabling seamless document import from enterprise cloud storage.
- 📊 **Production Observability**: Built-in OpenTelemetry support for traces, metrics, and logs, enabling comprehensive monitoring with your existing observability stack.
- ⚖️ **Horizontal Scalability**: Redis-backed session management and WebSocket support for multi-worker and multi-node deployments behind load balancers.
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
@ -79,6 +61,33 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
## Sponsors 🙌
#### Emerald
<table>
<tr>
<td>
<a href="https://n8n.io/" target="_blank">
<img src="https://docs.openwebui.com/sponsors/logos/n8n.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
</a>
</td>
<td>
N8N • Does your interface have a backend yet?<br>Try <a href="https://n8n.io/">n8n</a>
</td>
</tr>
<tr>
<td>
<a href="https://warp.dev/open-webui" target="_blank">
<img src="https://docs.openwebui.com/sponsors/logos/warp.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
</a>
</td>
<td>
Warp • The intelligent terminal for developers
</td>
</tr>
</table>
---
We are incredibly grateful for the generous support of our sponsors. Their contributions help us to maintain and improve our project, ensuring we can continue to deliver quality work to our community. Thank you!
@ -172,8 +181,6 @@ After installation, you can access Open WebUI at [http://localhost:3000](http://
We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
Look at the [Local Development Guide](https://docs.openwebui.com/getting-started/advanced-topics/development) for instructions on setting up a local development environment.
### Troubleshooting
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
@ -190,6 +197,14 @@ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=
### Keeping Your Docker Installation Up-to-Date
In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
```bash
docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
```
In the last part of the command, replace `open-webui` with your container name if it is different.
Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
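A manual alternative to Watchtower, sketched under the assumption that the container was started with the project's usual quick-start flags (volume `open-webui`, port mapping `3000:8080`, image `ghcr.io/open-webui/open-webui:main`); adjust these to your actual setup:

```bash
# Pull the newer image, then recreate the container reusing the same data volume.
docker pull ghcr.io/open-webui/open-webui:main
docker stop open-webui && docker rm open-webui
docker run -d -p 3000:8080 -v open-webui:/app/backend/data \
  --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```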
### Using the Dev Branch 🌙
@ -217,7 +232,7 @@ Discover upcoming features on our roadmap in the [Open WebUI Documentation](http
## License 📜
This project contains code under multiple licenses. The current codebase includes components licensed under the Open WebUI License with an additional requirement to preserve the "Open WebUI" branding, as well as prior contributions under their respective original licenses. For a detailed record of license changes and the applicable terms for each section of the code, please refer to [LICENSE_HISTORY](./LICENSE_HISTORY). For complete and updated licensing details, please see the [LICENSE](./LICENSE) and [LICENSE_HISTORY](./LICENSE_HISTORY) files.
This project is licensed under the [Open WebUI License](LICENSE), a revised BSD-3-Clause license. You receive all the same rights as the classic BSD-3 license: you can use, modify, and distribute the software, including in proprietary and commercial products, with minimal restrictions. The only additional requirement is to preserve the "Open WebUI" branding, as detailed in the LICENSE file. For full terms, see the [LICENSE](LICENSE) document. 📄
## Support 💬


@ -1,3 +1,2 @@
export CORS_ALLOW_ORIGIN="http://localhost:5173;http://localhost:8080"
PORT="${PORT:-8080}"
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload


@ -10,7 +10,7 @@ script_location = migrations
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = ..
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.

File diff suppressed because it is too large.


@ -38,14 +38,13 @@ class ERROR_MESSAGES(str, Enum):
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long."
INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again."
)
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INCORRECT_PASSWORD = (
INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again."
)
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
@ -105,10 +104,6 @@ class ERROR_MESSAGES(str, Enum):
)
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
INVALID_PASSWORD = lambda err="": (
err if err else "The password does not meet the required validation criteria."
)
class TASKS(str, Enum):
def __str__(self) -> str:
@ -116,7 +111,6 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"


@ -5,11 +5,7 @@ import os
import pkgutil
import sys
import shutil
from uuid import uuid4
from pathlib import Path
from cryptography.hazmat.primitives import serialization
import re
import markdown
from bs4 import BeautifulSoup
@ -19,17 +15,14 @@ from open_webui.constants import ERROR_MESSAGES
# Load .env file
####################################
# Use .resolve() to get the canonical path, removing any '..' or '.' components
ENV_FILE_PATH = Path(__file__).resolve()
OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
print(OPEN_WEBUI_DIR)
# OPEN_WEBUI_DIR should be the directory where env.py resides (open_webui/)
OPEN_WEBUI_DIR = ENV_FILE_PATH.parent
BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
# BACKEND_DIR is the parent of OPEN_WEBUI_DIR (backend/)
BACKEND_DIR = OPEN_WEBUI_DIR.parent
# BASE_DIR is the parent of BACKEND_DIR (open-webui-dev/)
BASE_DIR = BACKEND_DIR.parent
print(BACKEND_DIR)
print(BASE_DIR)
try:
from dotenv import find_dotenv, load_dotenv
@ -85,7 +78,32 @@ if "cuda_error" in locals():
log.exception(cuda_error)
del cuda_error
SRC_LOG_LEVELS = {} # Legacy variable, do not remove
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
"SOCKET",
"OAUTH",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
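Per the loop above, every source accepts its own `<SOURCE>_LOG_LEVEL` variable and falls back to the global level when the value is unset or not a recognized level name; for example (values are illustrative, and it is assumed the global level is likewise configured via the environment):

```bash
# Keep most components at INFO, but turn on verbose logging for retrieval and the Ollama client.
export GLOBAL_LOG_LEVEL=INFO
export RAG_LOG_LEVEL=DEBUG
export OLLAMA_LOG_LEVEL=DEBUG
```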
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
@ -114,10 +132,6 @@ else:
VERSION = PACKAGE_DATA["version"]
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "")
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
# Function to parse each section
def parse_section(section):
items = []
@ -183,7 +197,6 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# ENABLE_FORWARD_USER_INFO_HEADERS
####################################
@ -192,11 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
)
# Experimental feature, may be removed in future
ENABLE_STAR_SESSIONS_MIDDLEWARE = (
os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true"
)
####################################
# WEBUI_BUILD_HASH
####################################
@ -256,43 +264,21 @@ else:
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
DATABASE_TYPE = os.environ.get("DATABASE_TYPE")
DATABASE_USER = os.environ.get("DATABASE_USER")
DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD")
DATABASE_CRED = ""
if DATABASE_USER:
DATABASE_CRED += f"{DATABASE_USER}"
if DATABASE_PASSWORD:
DATABASE_CRED += f":{DATABASE_PASSWORD}"
DB_VARS = {
"db_type": DATABASE_TYPE,
"db_cred": DATABASE_CRED,
"db_host": os.environ.get("DATABASE_HOST"),
"db_port": os.environ.get("DATABASE_PORT"),
"db_name": os.environ.get("DATABASE_NAME"),
}
if all(DB_VARS.values()):
DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"):
# Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set
DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db"
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
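Given the `DB_VARS` assembly above, supplying all five pieces plus credentials yields a URL of the form `<type>://<user>:<password>@<host>:<port>/<name>`; an illustrative PostgreSQL configuration (every value below is a placeholder):

```bash
# Equivalent to DATABASE_URL="postgresql://openwebui:secret@db.internal:5432/openwebui"
export DATABASE_TYPE=postgresql
export DATABASE_USER=openwebui
export DATABASE_PASSWORD=secret
export DATABASE_HOST=db.internal
export DATABASE_PORT=5432
export DATABASE_NAME=openwebui
```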
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None)
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
if DATABASE_POOL_SIZE != None:
if DATABASE_POOL_SIZE == "":
DATABASE_POOL_SIZE = 0
else:
try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception:
DATABASE_POOL_SIZE = None
DATABASE_POOL_SIZE = 0
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
@ -324,26 +310,6 @@ else:
except Exception:
DATABASE_POOL_RECYCLE = 3600
DATABASE_ENABLE_SQLITE_WAL = (
os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true"
)
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get(
"DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL", None
)
if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None:
try:
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float(
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
)
except Exception:
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0
# Enable public visibility of active user count (when disabled, only admins can see it)
ENABLE_PUBLIC_ACTIVE_USERS_COUNT = (
os.environ.get("ENABLE_PUBLIC_ACTIVE_USERS_COUNT", "True").lower() == "true"
)
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
@ -352,36 +318,14 @@ ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
)
ENABLE_QUERIES_CACHE = os.environ.get("ENABLE_QUERIES_CACHE", "False").lower() == "true"
####################################
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
# Maximum number of retries for Redis operations when using Sentinel fail-over
REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2")
try:
REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT)
if REDIS_SENTINEL_MAX_RETRY_COUNT < 1:
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
except ValueError:
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
REDIS_SOCKET_CONNECT_TIMEOUT = os.environ.get("REDIS_SOCKET_CONNECT_TIMEOUT", "")
try:
REDIS_SOCKET_CONNECT_TIMEOUT = float(REDIS_SOCKET_CONNECT_TIMEOUT)
except ValueError:
REDIS_SOCKET_CONNECT_TIMEOUT = None
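These Sentinel settings are the raw strings consumed elsewhere in the codebase; purely as an illustration, this is roughly how they map onto a redis-py Sentinel client (the host names and the "mymaster" service name below are placeholders, not values from the repository):

from redis.sentinel import Sentinel

hosts = "sentinel-1,sentinel-2"                  # REDIS_SENTINEL_HOSTS
port = int("26379")                              # REDIS_SENTINEL_PORT
timeout = 5.0                                    # REDIS_SOCKET_CONNECT_TIMEOUT

sentinels = [(h.strip(), port) for h in hosts.split(",") if h.strip()]
sentinel = Sentinel(sentinels, socket_connect_timeout=timeout)
client = sentinel.master_for("mymaster")         # placeholder Sentinel service name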
####################################
# UVICORN WORKERS
####################################
@@ -401,14 +345,6 @@ except ValueError:
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
ENABLE_INITIAL_ADMIN_SIGNUP = (
os.environ.get("ENABLE_INITIAL_ADMIN_SIGNUP", "False").lower() == "true"
)
ENABLE_SIGNUP_PASSWORD_CONFIRMATION = (
os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true"
)
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
@@ -418,23 +354,6 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
)
ENABLE_PASSWORD_VALIDATION = (
os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true"
)
PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get(
"PASSWORD_VALIDATION_REGEX_PATTERN",
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$",
)
try:
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN)
except Exception as e:
log.error(f"Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}")
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$"
)
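The default pattern requires at least one lowercase letter, one uppercase letter, one digit, one non-alphanumeric character, and a minimum length of eight. A quick standalone check of that behaviour:

import re

pattern = re.compile(r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$")
assert pattern.match("Str0ng!pw") is not None    # all four character classes present
assert pattern.match("weakpassword") is None     # missing uppercase, digit and symbol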
BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
@@ -475,197 +394,18 @@ WEBUI_AUTH_COOKIE_SECURE = (
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
ENABLE_COMPRESSION_MIDDLEWARE = (
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
)
####################################
# OAUTH Configuration
####################################
ENABLE_OAUTH_EMAIL_FALLBACK = (
os.environ.get("ENABLE_OAUTH_EMAIL_FALLBACK", "False").lower() == "true"
)
ENABLE_OAUTH_ID_TOKEN_COOKIE = (
os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
)
OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get(
"OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY
)
OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
"OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
)
####################################
# SCIM Configuration
####################################
ENABLE_SCIM = (
os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower()
== "true"
)
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
####################################
# LICENSE_KEY
####################################
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
LICENSE_BLOB = None
LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data")
if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH):
with open(LICENSE_BLOB_PATH, "rb") as f:
LICENSE_BLOB = f.read()
LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "")
pk = None
if LICENSE_PUBLIC_KEY:
pk = serialization.load_pem_public_key(
f"""
-----BEGIN PUBLIC KEY-----
{LICENSE_PUBLIC_KEY}
-----END PUBLIC KEY-----
""".encode(
"utf-8"
)
)
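LICENSE_PUBLIC_KEY is therefore expected to hold only the base64 body of a PEM public key, without the BEGIN/END lines, which the code re-wraps before loading. A sketch of that round trip using a throwaway Ed25519 key (the key type actually used for license verification is not specified here):

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519

pem = (
    ed25519.Ed25519PrivateKey.generate()
    .public_key()
    .public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    .decode()
)
# Keep only the base64 body, i.e. the shape LICENSE_PUBLIC_KEY takes
body = "\n".join(line for line in pem.splitlines() if "PUBLIC KEY" not in line)
reloaded = serialization.load_pem_public_key(
    f"-----BEGIN PUBLIC KEY-----\n{body}\n-----END PUBLIC KEY-----\n".encode("utf-8")
)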
####################################
# MODELS
####################################
ENABLE_CUSTOM_MODEL_FALLBACK = (
os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
)
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
if MODELS_CACHE_TTL == "":
MODELS_CACHE_TTL = None
else:
try:
MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
except Exception:
MODELS_CACHE_TTL = 1
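The empty-string-means-disabled / fall-back-on-parse-error convention above is repeated for several of the integer settings that follow (stream delta chunk size, tool-call retries, stream buffer size). A hypothetical helper, not present in the codebase, that captures the same convention:

import os

def env_int(name: str, default: int, empty_value=None):
    # Empty string -> empty_value, unparsable value -> default, otherwise int
    raw = os.environ.get(name, str(default))
    if raw == "":
        return empty_value
    try:
        return int(raw)
    except Exception:
        return default

# Equivalent to the MODELS_CACHE_TTL handling above
models_cache_ttl = env_int("MODELS_CACHE_TTL", 1, empty_value=None)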
####################################
# CHAT
####################################
ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = (
os.environ.get("ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION", "False").lower()
== "true"
)
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
"CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
)
if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "":
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
else:
try:
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int(
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE
)
except Exception:
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get(
"CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30"
)
if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "":
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
else:
try:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = int(CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES)
except Exception:
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get(
"CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", ""
)
if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "":
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
else:
try:
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
)
except Exception:
CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
####################################
# WEBSOCKET SUPPORT
####################################
ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "")
if WEBSOCKET_REDIS_OPTIONS == "":
if REDIS_SOCKET_CONNECT_TIMEOUT:
WEBSOCKET_REDIS_OPTIONS = {
"socket_connect_timeout": REDIS_SOCKET_CONNECT_TIMEOUT
}
else:
log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None")
WEBSOCKET_REDIS_OPTIONS = None
else:
try:
WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS)
except Exception:
log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None")
WEBSOCKET_REDIS_OPTIONS = None
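WEBSOCKET_REDIS_OPTIONS is expected to be a JSON object of keyword arguments for the Redis client; when it is unset, only the socket connect timeout (if any) is carried over. An illustration of the expected shape (the keys shown are ordinary redis-py kwargs, chosen here as examples):

import json, os

os.environ["WEBSOCKET_REDIS_OPTIONS"] = '{"socket_connect_timeout": 5, "socket_keepalive": true}'
options = json.loads(os.environ["WEBSOCKET_REDIS_OPTIONS"])
# -> {'socket_connect_timeout': 5, 'socket_keepalive': True}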
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_CLUSTER = (
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
)
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
try:
WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout)
except ValueError:
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
WEBSOCKET_SERVER_LOGGING = (
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
)
WEBSOCKET_SERVER_ENGINEIO_LOGGING = (
os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
)
WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20")
try:
WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT)
except ValueError:
WEBSOCKET_SERVER_PING_TIMEOUT = 20
WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25")
try:
WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL)
except ValueError:
WEBSOCKET_SERVER_PING_INTERVAL = 25
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@@ -764,32 +504,19 @@ else:
# OFFLINE_MODE
####################################
ENABLE_VERSION_UPDATE_CHECK = (
os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true"
)
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
ENABLE_VERSION_UPDATE_CHECK = False
####################################
# AUDIT LOGGING
####################################
# Where to store the audit log file
# Defaults to DATA_DIR/audit.log. To override it, set AUDIT_LOGS_FILE_PATH to a
# full path, e.g. /app/audit.log
AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log")
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
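The rotation size is a human-readable string such as "10MB". How it is interpreted depends on the audit logging backend; purely as an illustration, this is one way such a string could be turned into bytes for a standard-library RotatingFileHandler (the parser below is not part of the codebase):

import re
from logging.handlers import RotatingFileHandler

def parse_size(size: str) -> int:
    # '10MB', '10 MB', '512KB' -> number of bytes
    m = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*([KMG]?B)\s*", size, re.IGNORECASE)
    number, unit = float(m.group(1)), m.group(2).upper()
    return int(number * {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}[unit])

handler = RotatingFileHandler("audit.log", maxBytes=parse_size("10MB"), backupCount=3)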
# Comma-separated list of logger names to use for audit logging
# Defaults to "uvicorn.access", Uvicorn's access logger
# Add more logger names to this list to capture additional logs
AUDIT_UVICORN_LOGGER_NAMES = os.getenv(
"AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access"
).split(",")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
@@ -810,34 +537,9 @@ AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
####################################
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true"
ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true"
ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true"
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT
)
OTEL_EXPORTER_OTLP_INSECURE = (
os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
)
OTEL_METRICS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_LOGS_EXPORTER_OTLP_INSECURE = (
os.environ.get(
"OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE)
).lower()
== "true"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
@@ -845,33 +547,6 @@ OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)
OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME
)
OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get(
"OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD
)
OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_OTLP_SPAN_EXPORTER", "grpc"
).lower() # grpc or http
OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http
OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER
).lower() # grpc or http
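The basic-auth username/password pairs above are typically attached to the OTLP exporters as an HTTP Authorization header (or gRPC metadata). A standard-library-only sketch of building that header; the exporter wiring itself is omitted and the credential values are placeholders:

import base64

username, password = "otel-user", "otel-pass"        # OTEL_BASIC_AUTH_USERNAME / _PASSWORD
token = base64.b64encode(f"{username}:{password}".encode()).decode()
headers = {"Authorization": f"Basic {token}"}         # http exporters commonly accept a dict of headers
grpc_metadata = tuple(headers.items())                # grpc exporters commonly accept key/value pairs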
####################################
# TOOLS/FUNCTIONS PIP OPTIONS

View file

@@ -19,14 +19,12 @@ from fastapi import (
from starlette.responses import Response, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
@@ -37,7 +35,7 @@ from open_webui.utils.plugin import (
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
from open_webui.env import GLOBAL_LOG_LEVEL
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.utils.misc import (
add_or_update_system_message,
@@ -48,32 +46,21 @@ from open_webui.utils.misc import (
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_system_prompt_to_body,
apply_model_system_prompt_to_body,
)
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
Valves = function_module.Valves
valves = Functions.get_function_valves_by_id(pipe_id)
if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
raise e
else:
function_module.valves = Valves()
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
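The valves mechanism above instantiates a function's Valves pydantic model from the stored overrides, dropping None values so the model's own defaults still apply. A self-contained sketch of that pattern (the Valves fields and stored values here are hypothetical):

from typing import Optional
from pydantic import BaseModel

class Valves(BaseModel):                 # stand-in for a function module's Valves class
    api_base: str = "http://localhost:11434"
    temperature: float = 0.7
    api_key: Optional[str] = None

stored = {"temperature": 0.2, "api_key": None}   # e.g. what get_function_valves_by_id returns
valves = Valves(**{k: v for k, v in stored.items() if v is not None})
# api_base keeps its default, temperature is overridden, api_key stays None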
@@ -82,75 +69,65 @@ async def get_function_models(request):
pipe_models = []
for pipe in pipes:
try:
function_module = get_function_module_by_id(request, pipe.id)
function_module = get_function_module_by_id(request, pipe.id)
has_user_valves = False
if hasattr(function_module, "UserValves"):
has_user_valves = True
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if function is a manifold
if hasattr(function_module, "pipes"):
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
"has_user_valves": has_user_valves,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
"has_user_valves": has_user_valves,
}
)
except Exception as e:
log.exception(e)
continue
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
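The manifold handling above accepts pipes as a plain list, a synchronous callable, or an async callable. The same resolution logic, reduced to a standalone sketch:

import asyncio

async def resolve_pipes(pipes):
    # Return the sub-pipe list whether `pipes` is a list, a sync callable, or a coroutine function
    if callable(pipes):
        if asyncio.iscoroutinefunction(pipes):
            return await pipes()
        return pipes()
    return pipes

async def demo():
    assert await resolve_pipes([{"id": "a", "name": "A"}]) == [{"id": "a", "name": "A"}]
    assert await resolve_pipes(lambda: [{"id": "b", "name": "B"}]) == [{"id": "b", "name": "B"}]

    async def async_pipes():
        return [{"id": "c", "name": "C"}]

    assert await resolve_pipes(async_pipes) == [{"id": "c", "name": "C"}]

asyncio.run(demo())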
@@ -241,16 +218,6 @@ async def generate_function_chat_completion(
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
@@ -260,12 +227,16 @@ async def generate_function_chat_completion(
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__oauth_token__": oauth_token,
"__request__": request,
}
extra_params["__tools__"] = await get_tools(
extra_params["__tools__"] = get_tools(
request,
tool_ids,
user,
@@ -282,11 +253,8 @@ async def generate_function_chat_completion(
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)

View file

@@ -1,4 +1,3 @@
import os
import json
import logging
from contextlib import contextmanager
@@ -9,14 +8,14 @@ from open_webui.env import (
OPEN_WEBUI_DIR,
DATABASE_URL,
DATABASE_SCHEMA,
SRC_LOG_LEVELS,
DATABASE_POOL_MAX_OVERFLOW,
DATABASE_POOL_RECYCLE,
DATABASE_POOL_SIZE,
DATABASE_POOL_TIMEOUT,
DATABASE_ENABLE_SQLITE_WAL,
)
from peewee_migrate import Router
from sqlalchemy import Dialect, create_engine, MetaData, event, types
from sqlalchemy import Dialect, create_engine, MetaData, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
@@ -24,6 +23,7 @@ from sqlalchemy.sql.type_api import _T
from typing_extensions import Self
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(types.TypeDecorator):
@@ -62,9 +62,6 @@ def handle_peewee_migration(DATABASE_URL):
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
log.warning(
"Hint: If your database password contains special characters, you may need to URL-encode it."
)
raise
finally:
# Properly closing the database connection
@@ -79,66 +76,25 @@ handle_peewee_migration(DATABASE_URL)
SQLALCHEMY_DATABASE_URL = DATABASE_URL
# Handle SQLCipher URLs
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
# Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection():
import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{database_password}'")
return conn
engine = create_engine(
"sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection,
echo=False,
)
log.info("Connected to encrypted SQLite database using SQLCipher")
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
if "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
def on_connect(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
if DATABASE_ENABLE_SQLITE_WAL:
cursor.execute("PRAGMA journal_mode=WAL")
else:
cursor.execute("PRAGMA journal_mode=DELETE")
cursor.close()
event.listen(engine, "connect", on_connect)
else:
if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0:
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
pool_size=DATABASE_POOL_SIZE,
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
pool_timeout=DATABASE_POOL_TIMEOUT,
pool_recycle=DATABASE_POOL_RECYCLE,
pool_pre_ping=True,
poolclass=QueuePool,
)
else:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
if DATABASE_POOL_SIZE > 0:
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
pool_size=DATABASE_POOL_SIZE,
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
pool_timeout=DATABASE_POOL_TIMEOUT,
pool_recycle=DATABASE_POOL_RECYCLE,
pool_pre_ping=True,
poolclass=QueuePool,
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
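To confirm which journal mode the WAL toggle actually produced on a SQLite database file, a plain sqlite3 check is enough (the path below is illustrative; note that WAL mode is persisted in the file itself):

import sqlite3

conn = sqlite3.connect("data/webui.db")                 # illustrative path
mode = conn.execute("PRAGMA journal_mode;").fetchone()[0]
print(mode)   # 'wal' when DATABASE_ENABLE_SQLITE_WAL=true was applied, 'delete' otherwise
conn.close()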
SessionLocal = sessionmaker(

View file

@@ -1,7 +1,7 @@
import logging
import os
from contextvars import ContextVar
from open_webui.env import SRC_LOG_LEVELS
from peewee import *
from peewee import InterfaceError as PeeWeeInterfaceError
from peewee import PostgresqlDatabase
@@ -9,6 +9,7 @@ from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())
@@ -42,45 +43,24 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url):
# Check if using SQLCipher protocol
if db_url.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
from playhouse.sqlcipher_ext import SqlCipherDatabase
# Parse the database path from SQLCipher URL
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
db_path = db_url.replace("sqlite+sqlcipher://", "")
# Use Peewee's native SqlCipherDatabase with encryption
db = SqlCipherDatabase(db_path, passphrase=database_password)
db = connect(db_url, unquote_user=True, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to encrypted SQLite database using SQLCipher")
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)
db.connect(reuse_if_open=True)
elif isinstance(db, SqliteDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
else:
# Standard database connection (existing logic)
db = connect(db_url, unquote_user=True, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)
db.connect(reuse_if_open=True)
elif isinstance(db, SqliteDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
else:
raise ValueError("Unsupported database connection")
raise ValueError("Unsupported database connection")
return db

File diff suppressed because it is too large

View file

@@ -2,8 +2,8 @@ from logging.config import fileConfig
from alembic import context
from open_webui.models.auths import Auth
from open_webui.env import DATABASE_URL, DATABASE_PASSWORD
from sqlalchemy import engine_from_config, pool, create_engine
from open_webui.env import DATABASE_URL
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -62,38 +62,11 @@ def run_migrations_online() -> None:
and associate a connection with the context.
"""
# Handle SQLCipher URLs
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
if db_path.startswith("/"):
db_path = db_path[1:] # Remove leading slash for relative paths
# Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection():
import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
return conn
connectable = create_engine(
"sqlite://", # Dummy URL since we're using creator
creator=create_sqlcipher_connection,
echo=False,
)
else:
# Standard database connection (existing logic)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)

View file

@@ -1,46 +0,0 @@
"""Add indexes
Revision ID: 018012973d35
Revises: d31026856c01
Create Date: 2025-08-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "018012973d35"
down_revision = "d31026856c01"
branch_labels = None
depends_on = None
def upgrade():
# Chat table indexes
op.create_index("folder_id_idx", "chat", ["folder_id"])
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"])
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"])
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"])
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"])
# Tag table index
op.create_index("user_id_idx", "tag", ["user_id"])
# Function table index
op.create_index("is_global_idx", "function", ["is_global"])
def downgrade():
# Chat table indexes
op.drop_index("folder_id_idx", table_name="chat")
op.drop_index("user_id_pinned_idx", table_name="chat")
op.drop_index("user_id_archived_idx", table_name="chat")
op.drop_index("updated_at_user_id_idx", table_name="chat")
op.drop_index("folder_id_user_id_idx", table_name="chat")
# Tag table index
op.drop_index("user_id_idx", table_name="tag")
# Function table index
op.drop_index("is_global_idx", table_name="function")

View file

@@ -1,103 +0,0 @@
"""Update messages and channel member table
Revision ID: 2f1211949ecc
Revises: 37f288994c47
Create Date: 2025-11-27 03:07:56.200231
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "2f1211949ecc"
down_revision: Union[str, None] = "37f288994c47"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# New columns to be added to channel_member table
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
op.add_column(
"channel_member",
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
default=True,
server_default=sa.sql.expression.true(),
),
)
op.add_column(
"channel_member",
sa.Column(
"is_channel_muted",
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column(
"channel_member",
sa.Column(
"is_channel_pinned",
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
op.add_column(
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
)
op.add_column(
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
)
op.add_column(
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
)
op.add_column(
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
)
# New columns to be added to message table
op.add_column(
"message",
sa.Column(
"is_pinned",
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("channel_member", "updated_at")
op.drop_column("channel_member", "last_read_at")
op.drop_column("channel_member", "meta")
op.drop_column("channel_member", "data")
op.drop_column("channel_member", "is_channel_pinned")
op.drop_column("channel_member", "is_channel_muted")
op.drop_column("message", "pinned_by")
op.drop_column("message", "pinned_at")
op.drop_column("message", "is_pinned")

View file

@@ -1,146 +0,0 @@
"""add_group_member_table
Revision ID: 37f288994c47
Revises: a5c220713937
Create Date: 2025-11-17 03:45:25.123939
"""
import uuid
import time
import json
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "37f288994c47"
down_revision: Union[str, None] = "a5c220713937"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. Create new table
op.create_table(
"group_member",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column(
"group_id",
sa.Text(),
sa.ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"user_id",
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
)
connection = op.get_bind()
# 2. Read existing group with user_ids JSON column
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(group_table.c.id, group_table.c.user_ids)
).fetchall()
print(results)
# 3. Insert members into group_member table
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
now = int(time.time())
for group_id, user_ids in results:
if not user_ids:
continue
if isinstance(user_ids, str):
try:
user_ids = json.loads(user_ids)
except Exception:
continue # skip invalid JSON
if not isinstance(user_ids, list):
continue
rows = [
{
"id": str(uuid.uuid4()),
"group_id": group_id,
"user_id": uid,
"created_at": now,
"updated_at": now,
}
for uid in user_ids
]
if rows:
connection.execute(gm_table.insert(), rows)
# 4. Optionally drop the old column
with op.batch_alter_table("group") as batch:
batch.drop_column("user_ids")
def downgrade():
# Reverse: restore user_ids column
with op.batch_alter_table("group") as batch:
batch.add_column(sa.Column("user_ids", sa.JSON()))
connection = op.get_bind()
gm_table = sa.Table(
"group_member",
sa.MetaData(),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
group_table = sa.Table(
"group",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()),
)
# Build JSON arrays again
results = connection.execute(sa.select(group_table.c.id)).fetchall()
for (group_id,) in results:
members = connection.execute(
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
).fetchall()
member_ids = [m[0] for m in members]
connection.execute(
group_table.update()
.where(group_table.c.id == group_id)
.values(user_ids=member_ids)
)
# Drop the new table
op.drop_table("group_member")
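The upgrade above flattens each group's user_ids JSON array into one group_member row per (group, user) pair. The core of that transformation, shown without a database:

import json, time, uuid

group_rows = [("g1", '["u1", "u2"]'), ("g2", None)]   # (group_id, user_ids) as read from the old column
now = int(time.time())

member_rows = []
for group_id, user_ids in group_rows:
    if not user_ids:
        continue
    if isinstance(user_ids, str):
        user_ids = json.loads(user_ids)
    member_rows += [
        {"id": str(uuid.uuid4()), "group_id": group_id, "user_id": uid,
         "created_at": now, "updated_at": now}
        for uid in user_ids
    ]
# -> two rows for g1 (u1 and u2), nothing for g2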

View file

@@ -1,80 +0,0 @@
"""Add oauth_session table
Revision ID: 38d63c18f30f
Revises: 3af16a1c9fb6
Create Date: 2025-09-08 14:19:59.583921
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "38d63c18f30f"
down_revision: Union[str, None] = "3af16a1c9fb6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Ensure the 'id' column in the 'user' table is unique and the primary key, so it can be referenced by the ForeignKey below
inspector = sa.inspect(op.get_bind())
columns = inspector.get_columns("user")
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
id_column = next((col for col in columns if col["name"] == "id"), None)
if id_column and not id_column.get("unique", False):
unique_constraints = inspector.get_unique_constraints("user")
unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
with op.batch_alter_table("user") as batch_op:
# If primary key is wrong, drop it
if pk_columns and pk_columns != ["id"]:
batch_op.drop_constraint(
inspector.get_pk_constraint("user")["name"], type_="primary"
)
# Add unique constraint if missing
if ("id",) not in unique_columns:
batch_op.create_unique_constraint("uq_user_id", ["id"])
# Re-create correct primary key
batch_op.create_primary_key("pk_user_id", ["id"])
# Create oauth_session table
op.create_table(
"oauth_session",
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
sa.Column(
"user_id",
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("provider", sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("expires_at", sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
# Create indexes for better performance
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
op.create_index(
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
)
def downgrade() -> None:
# Drop indexes first
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
# Drop the table
op.drop_table("oauth_session")

View file

@@ -1,32 +0,0 @@
"""update user table
Revision ID: 3af16a1c9fb6
Revises: 018012973d35
Create Date: 2025-08-21 02:07:18.078283
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "3af16a1c9fb6"
down_revision: Union[str, None] = "018012973d35"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "username")
op.drop_column("user", "bio")
op.drop_column("user", "gender")
op.drop_column("user", "date_of_birth")

View file

@@ -1,169 +0,0 @@
"""Add knowledge_file table
Revision ID: 3e0e00844bb0
Revises: 90ef40d4714e
Create Date: 2025-12-02 06:54:19.401334
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import open_webui.internal.db
import time
import json
import uuid
# revision identifiers, used by Alembic.
revision: str = "3e0e00844bb0"
down_revision: Union[str, None] = "90ef40d4714e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"knowledge_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"knowledge_id",
sa.Text(),
sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"file_id",
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
sa.Index("ix_knowledge_file_file_id", "file_id"),
sa.Index("ix_knowledge_file_user_id", "user_id"),
# unique constraints
sa.UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
), # prevent duplicate entries
)
connection = op.get_bind()
# 2. Read existing knowledge rows with their data JSON column
knowledge_table = sa.Table(
"knowledge",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
)
).fetchall()
# 3. Insert file references into the knowledge_file table
kf_table = sa.Table(
"knowledge_file",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("knowledge_id", sa.Text()),
sa.Column("file_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
)
file_table = sa.Table(
"file",
sa.MetaData(),
sa.Column("id", sa.Text()),
)
now = int(time.time())
for knowledge_id, user_id, data in results:
if not data:
continue
if isinstance(data, str):
try:
data = json.loads(data)
except Exception:
continue # skip invalid JSON
if not isinstance(data, dict):
continue
file_ids = data.get("file_ids", [])
for file_id in file_ids:
file_exists = connection.execute(
sa.select(file_table.c.id).where(file_table.c.id == file_id)
).fetchone()
if not file_exists:
continue # skip non-existing files
row = {
"id": str(uuid.uuid4()),
"user_id": user_id,
"knowledge_id": knowledge_id,
"file_id": file_id,
"created_at": now,
"updated_at": now,
}
connection.execute(kf_table.insert().values(**row))
with op.batch_alter_table("knowledge") as batch:
batch.drop_column("data")
def downgrade() -> None:
# 1. Add back the old data column
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
connection = op.get_bind()
# 2. Read knowledge_file entries and reconstruct data JSON
knowledge_table = sa.Table(
"knowledge",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("data", sa.JSON()),
)
kf_table = sa.Table(
"knowledge_file",
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("knowledge_id", sa.Text()),
sa.Column("file_id", sa.Text()),
)
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
for (knowledge_id,) in results:
file_ids = connection.execute(
sa.select(kf_table.c.file_id).where(kf_table.c.knowledge_id == knowledge_id)
).fetchall()
file_ids_list = [fid for (fid,) in file_ids]
data_json = {"file_ids": file_ids_list}
connection.execute(
knowledge_table.update()
.where(knowledge_table.c.id == knowledge_id)
.values(data=data_json)
)
# 3. Drop the knowledge_file table
op.drop_table("knowledge_file")

View file

@@ -1,54 +0,0 @@
"""Add channel file table
Revision ID: 6283dc0e4d8d
Revises: 3e0e00844bb0
Create Date: 2025-12-10 15:11:39.424601
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "6283dc0e4d8d"
down_revision: Union[str, None] = "3e0e00844bb0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"channel_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"channel_id",
sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"file_id",
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_channel_file_channel_id", "channel_id"),
sa.Index("ix_channel_file_file_id", "file_id"),
sa.Index("ix_channel_file_user_id", "user_id"),
# unique constraints
sa.UniqueConstraint(
"channel_id", "file_id", name="uq_channel_file_channel_file"
), # prevent duplicate entries
)
def downgrade() -> None:
op.drop_table("channel_file")

View file

@@ -1,49 +0,0 @@
"""Update channel file and knowledge table
Revision ID: 81cc2ce44d79
Revises: 6283dc0e4d8d
Create Date: 2025-12-10 16:07:58.001282
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "81cc2ce44d79"
down_revision: Union[str, None] = "6283dc0e4d8d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add message_id column to channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"message_id",
sa.Text(),
sa.ForeignKey(
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
),
nullable=True,
)
)
# Add data column to knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op:
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True))
def downgrade() -> None:
# Remove message_id column from channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op:
batch_op.drop_column("message_id")
# Remove data column from knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op:
batch_op.drop_column("data")

View file

@@ -1,81 +0,0 @@
"""Update channel and channel members table
Revision ID: 90ef40d4714e
Revises: b10670c03dd5
Create Date: 2025-11-30 06:33:38.790341
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "90ef40d4714e"
down_revision: Union[str, None] = "b10670c03dd5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Update 'channel' table
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
# Update 'channel_member' table
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
op.add_column(
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
)
# Create 'channel_webhook' table
op.create_table(
"channel_webhook",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"channel_id",
sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
pass
def downgrade() -> None:
# Downgrade 'channel' table
op.drop_column("channel", "is_private")
op.drop_column("channel", "archived_at")
op.drop_column("channel", "archived_by")
op.drop_column("channel", "deleted_at")
op.drop_column("channel", "deleted_by")
op.drop_column("channel", "updated_by")
# Downgrade 'channel_member' table
op.drop_column("channel_member", "role")
op.drop_column("channel_member", "invited_by")
op.drop_column("channel_member", "invited_at")
# Drop 'channel_webhook' table
op.drop_table("channel_webhook")
pass

View file

@@ -1,34 +0,0 @@
"""Add reply_to_id column to message
Revision ID: a5c220713937
Revises: 38d63c18f30f
Create Date: 2025-09-27 02:24:18.058455
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a5c220713937"
down_revision: Union[str, None] = "38d63c18f30f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add 'reply_to_id' column to the 'message' table for replying to messages
op.add_column(
"message",
sa.Column("reply_to_id", sa.Text(), nullable=True),
)
pass
def downgrade() -> None:
# Remove 'reply_to_id' column from the 'message' table
op.drop_column("message", "reply_to_id")
pass

View file

@@ -1,251 +0,0 @@
"""Update user table
Revision ID: b10670c03dd5
Revises: 2f1211949ecc
Create Date: 2025-11-28 04:55:31.737538
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import json
import time
# revision identifiers, used by Alembic.
revision: str = "b10670c03dd5"
down_revision: Union[str, None] = "2f1211949ecc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
"""
SQLite requires manual removal of any indexes referencing a column
before ALTER TABLE ... DROP COLUMN can succeed.
"""
indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall()
for idx in indexes:
index_name = idx[1] # index name
# Get indexed columns
idx_info = conn.execute(
sa.text(f"PRAGMA index_info('{index_name}')")
).fetchall()
indexed_cols = [row[2] for row in idx_info] # col names
if column_name in indexed_cols:
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
def _convert_column_to_json(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
# SQLite cannot ALTER COLUMN → must recreate column
if dialect == "sqlite":
# 1. Add temporary column
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
# 2. Load old data
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for row in rows:
uid, raw = row
if raw is None:
parsed = None
else:
try:
parsed = json.loads(raw)
except Exception:
parsed = None # fallback safe behavior
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
{"val": json.dumps(parsed) if parsed else None, "id": uid},
)
# 3. Drop old TEXT column
op.drop_column(table, column)
# 4. Rename new JSON column → original name
op.alter_column(table, f"{column}_json", new_column_name=column)
else:
# PostgreSQL supports direct CAST
op.alter_column(
table,
column,
type_=sa.JSON(),
postgresql_using=f"{column}::json",
)
def _convert_column_to_text(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
if dialect == "sqlite":
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for uid, raw in rows:
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
{"val": json.dumps(raw) if raw else None, "id": uid},
)
op.drop_column(table, column)
op.alter_column(table, f"{column}_text", new_column_name=column)
else:
op.alter_column(
table,
column,
type_=sa.Text(),
postgresql_using=f"to_json({column})::text",
)
def upgrade() -> None:
op.add_column(
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
)
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
op.add_column(
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
)
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
# Convert info (TEXT/JSONField) → JSON
_convert_column_to_json("user", "info")
# Convert settings (TEXT/JSONField) → JSON
_convert_column_to_json("user", "settings")
op.create_table(
"api_key",
sa.Column("id", sa.Text(), primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
sa.Column("key", sa.Text(), unique=True, nullable=False),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("expires_at", sa.BigInteger(), nullable=True),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
).fetchall()
for uid, oauth_sub in users:
if oauth_sub:
# Example formats supported:
# provider@sub
# plain sub (stored as {"oidc": {"sub": sub}})
if "@" in oauth_sub:
provider, sub = oauth_sub.split("@", 1)
else:
provider, sub = "oidc", oauth_sub
oauth_json = json.dumps({provider: {"sub": sub}})
conn.execute(
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
{"oauth": oauth_json, "id": uid},
)
users_with_keys = conn.execute(
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
).fetchall()
now = int(time.time())
for uid, api_key in users_with_keys:
if api_key:
conn.execute(
sa.text(
"""
INSERT INTO api_key (id, user_id, key, created_at, updated_at)
VALUES (:id, :user_id, :key, :created_at, :updated_at)
"""
),
{
"id": f"key_{uid}",
"user_id": uid,
"key": api_key,
"created_at": now,
"updated_at": now,
},
)
if conn.dialect.name == "sqlite":
_drop_sqlite_indexes_for_column("user", "api_key", conn)
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("api_key")
batch_op.drop_column("oauth_sub")
def downgrade() -> None:
# --- 1. Restore old oauth_sub column ---
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
).fetchall()
for uid, oauth in users:
try:
data = json.loads(oauth)
provider = list(data.keys())[0]
sub = data[provider].get("sub")
oauth_sub = f"{provider}@{sub}"
except Exception:
oauth_sub = None
conn.execute(
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
{"oauth_sub": oauth_sub, "id": uid},
)
op.drop_column("user", "oauth")
# --- 2. Restore api_key field ---
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
# Restore values from api_key
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
for uid, key in keys:
conn.execute(
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
{"key": key, "id": uid},
)
# Drop new table
op.drop_table("api_key")
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("profile_banner_image_url")
batch_op.drop_column("timezone")
batch_op.drop_column("presence_state")
batch_op.drop_column("status_emoji")
batch_op.drop_column("status_message")
batch_op.drop_column("status_expires_at")
# Convert info (JSON) → TEXT
_convert_column_to_text("user", "info")
# Convert settings (JSON) → TEXT
_convert_column_to_text("user", "settings")
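The oauth_sub conversion in the upgrade above turns the legacy "provider@sub" string (or a bare sub) into the new nested oauth JSON keyed by provider. The same mapping as a standalone function:

import json

def oauth_sub_to_json(oauth_sub: str) -> str:
    # Legacy 'provider@sub' (or a plain sub) -> {"provider": {"sub": sub}}
    if "@" in oauth_sub:
        provider, sub = oauth_sub.split("@", 1)
    else:
        provider, sub = "oidc", oauth_sub
    return json.dumps({provider: {"sub": sub}})

assert json.loads(oauth_sub_to_json("google@1234")) == {"google": {"sub": "1234"}}
assert json.loads(oauth_sub_to_json("abcd")) == {"oidc": {"sub": "abcd"}}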

View file

@@ -1,57 +0,0 @@
"""Add chat_file table
Revision ID: c440947495f3
Revises: 81cc2ce44d79
Create Date: 2025-12-21 20:27:41.694897
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c440947495f3"
down_revision: Union[str, None] = "81cc2ce44d79"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"chat_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column(
"chat_id",
sa.Text(),
sa.ForeignKey("chat.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"file_id",
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("message_id", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_chat_file_chat_id", "chat_id"),
sa.Index("ix_chat_file_file_id", "file_id"),
sa.Index("ix_chat_file_message_id", "message_id"),
sa.Index("ix_chat_file_user_id", "user_id"),
# unique constraints
sa.UniqueConstraint(
"chat_id", "file_id", name="uq_chat_file_chat_file"
), # prevent duplicate entries
)
pass
def downgrade() -> None:
op.drop_table("chat_file")
pass

View file

@@ -1,23 +0,0 @@
"""Update folder table data
Revision ID: d31026856c01
Revises: 9f0c9cd09105
Create Date: 2025-07-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "d31026856c01"
down_revision = "9f0c9cd09105"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
def downgrade():
op.drop_column("folder", "data")

View file

@@ -3,11 +3,14 @@ import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, UserProfileImageResponse, Users
from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
from open_webui.utils.auth import verify_password
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
@@ -17,7 +20,7 @@ log = logging.getLogger(__name__)
class Auth(Base):
__tablename__ = "auth"
id = Column(String, primary_key=True, unique=True)
id = Column(String, primary_key=True)
email = Column(String)
password = Column(Text)
active = Column(Boolean)
@@ -44,7 +47,15 @@ class ApiKey(BaseModel):
api_key: Optional[str] = None
class SigninResponse(Token, UserProfileImageResponse):
class UserResponse(BaseModel):
id: str
email: str
name: str
role: str
profile_image_url: str
class SigninResponse(Token, UserResponse):
pass
@@ -62,6 +73,11 @@ class ProfileImageUrlForm(BaseModel):
profile_image_url: str
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
class UpdatePasswordForm(BaseModel):
password: str
new_password: str
@@ -86,7 +102,7 @@ class AuthsTable:
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth: Optional[dict] = None,
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
log.info("insert_new_auth")
@@ -100,7 +116,7 @@ class AuthsTable:
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth=oauth
id, name, email, profile_image_url, role, oauth_sub
)
db.commit()
@@ -111,9 +127,7 @@ class AuthsTable:
else:
return None
def authenticate_user(
self, email: str, verify_password: callable
) -> Optional[UserModel]:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email)
@@ -124,7 +138,7 @@ class AuthsTable:
with get_db() as db:
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth:
if verify_password(auth.password):
if verify_password(password, auth.password):
return user
else:
return None

View file

@@ -4,24 +4,10 @@ import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import (
BigInteger,
Boolean,
Column,
ForeignKey,
String,
Text,
JSON,
UniqueConstraint,
case,
cast,
)
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
@@ -33,30 +19,19 @@ from sqlalchemy.sql import exists
class Channel(Base):
__tablename__ = "channel"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
user_id = Column(Text)
type = Column(Text, nullable=True)
name = Column(Text)
description = Column(Text, nullable=True)
# Used to indicate if the channel is private (for 'group' type channels)
is_private = Column(Boolean, nullable=True)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
updated_by = Column(Text, nullable=True)
archived_at = Column(BigInteger, nullable=True)
archived_by = Column(Text, nullable=True)
deleted_at = Column(BigInteger, nullable=True)
deleted_by = Column(Text, nullable=True)
class ChannelModel(BaseModel):
@@ -64,157 +39,17 @@ class ChannelModel(BaseModel):
id: str
user_id: str
type: Optional[str] = None
name: str
description: Optional[str] = None
is_private: Optional[bool] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch (time_ns)
updated_at: int # timestamp in epoch (time_ns)
updated_by: Optional[str] = None
archived_at: Optional[int] = None # timestamp in epoch (time_ns)
archived_by: Optional[str] = None
deleted_at: Optional[int] = None # timestamp in epoch (time_ns)
deleted_by: Optional[str] = None
class ChannelMember(Base):
__tablename__ = "channel_member"
id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False)
user_id = Column(Text, nullable=False)
role = Column(Text, nullable=True)
status = Column(Text, nullable=True)
is_active = Column(Boolean, nullable=False, default=True)
is_channel_muted = Column(Boolean, nullable=False, default=False)
is_channel_pinned = Column(Boolean, nullable=False, default=False)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
invited_at = Column(BigInteger, nullable=True)
invited_by = Column(Text, nullable=True)
joined_at = Column(BigInteger)
left_at = Column(BigInteger, nullable=True)
last_read_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class ChannelMemberModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
channel_id: str
user_id: str
role: Optional[str] = None
status: Optional[str] = None
is_active: bool = True
is_channel_muted: bool = False
is_channel_pinned: bool = False
data: Optional[dict] = None
meta: Optional[dict] = None
invited_at: Optional[int] = None # timestamp in epoch (time_ns)
invited_by: Optional[str] = None
joined_at: Optional[int] = None # timestamp in epoch (time_ns)
left_at: Optional[int] = None # timestamp in epoch (time_ns)
last_read_at: Optional[int] = None # timestamp in epoch (time_ns)
created_at: Optional[int] = None # timestamp in epoch (time_ns)
updated_at: Optional[int] = None # timestamp in epoch (time_ns)
class ChannelFile(Base):
__tablename__ = "channel_file"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text, nullable=False)
channel_id = Column(
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
)
message_id = Column(
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
)
class ChannelFileModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
channel_id: str
file_id: str
user_id: str
created_at: int # timestamp in epoch (time_ns)
updated_at: int # timestamp in epoch (time_ns)
class ChannelWebhook(Base):
__tablename__ = "channel_webhook"
id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False)
user_id = Column(Text, nullable=False)
name = Column(Text, nullable=False)
profile_image_url = Column(Text, nullable=True)
token = Column(Text, nullable=False)
last_used_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class ChannelWebhookModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
channel_id: str
user_id: str
name: str
profile_image_url: Optional[str] = None
token: str
last_used_at: Optional[int] = None # timestamp in epoch (time_ns)
created_at: int # timestamp in epoch (time_ns)
updated_at: int # timestamp in epoch (time_ns)
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
@ -222,95 +57,23 @@ class ChannelWebhookModel(BaseModel):
####################
class ChannelResponse(ChannelModel):
is_manager: bool = False
write_access: bool = False
user_count: Optional[int] = None
class ChannelForm(BaseModel):
name: str = ""
name: str
description: Optional[str] = None
is_private: Optional[bool] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
group_ids: Optional[list[str]] = None
user_ids: Optional[list[str]] = None
class CreateChannelForm(ChannelForm):
type: Optional[str] = None
class ChannelTable:
def _collect_unique_user_ids(
self,
invited_by: str,
user_ids: Optional[list[str]] = None,
group_ids: Optional[list[str]] = None,
) -> set[str]:
"""
Collect unique user ids from:
- invited_by
- user_ids
- each group in group_ids
Returns a set for efficient SQL diffing.
"""
users = set(user_ids or [])
users.add(invited_by)
for group_id in group_ids or []:
users.update(Groups.get_group_user_ids_by_id(group_id))
return users
def _create_membership_models(
self,
channel_id: str,
invited_by: str,
user_ids: set[str],
) -> list[ChannelMember]:
"""
Takes a set of NEW user IDs (already filtered to exclude existing members).
Returns ORM ChannelMember objects to be added.
"""
now = int(time.time_ns())
memberships = []
for uid in user_ids:
model = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"user_id": uid,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"invited_at": now,
"invited_by": invited_by,
"joined_at": now,
"left_at": None,
"last_read_at": now,
"created_at": now,
"updated_at": now,
}
)
memberships.append(ChannelMember(**model.model_dump()))
return memberships
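Together, the two helpers above deduplicate the invitee list (the inviter, any explicit user_ids, and the members of each group in group_ids) and then build one ChannelMember row per unique user. An illustrative composition using only names defined in this diff:

# Illustration only: composing the helpers for a new "group" channel.
table = ChannelTable()

unique_ids = table._collect_unique_user_ids(
    invited_by="user-a",
    user_ids=["user-b", "user-b", "user-a"],  # duplicates and the inviter collapse
    group_ids=[],                             # group members would be merged in here
)
# unique_ids == {"user-a", "user-b"}

memberships = table._create_membership_models(
    channel_id="channel-1",
    invited_by="user-a",
    user_ids=unique_ids,
)
# One ChannelMember ORM row per unique user, all stamped with the same time_ns values.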
def insert_new_channel(
self, form_data: CreateChannelForm, user_id: str
self, type: Optional[str], form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
channel = ChannelModel(
**{
**form_data.model_dump(),
"type": form_data.type if form_data.type else None,
"type": type,
"name": form_data.name.lower(),
"id": str(uuid.uuid4()),
"user_id": user_id,
@ -318,21 +81,9 @@ class ChannelTable:
"updated_at": int(time.time_ns()),
}
)
new_channel = Channel(**channel.model_dump())
if form_data.type in ["group", "dm"]:
users = self._collect_unique_user_ids(
invited_by=user_id,
user_ids=form_data.user_ids,
group_ids=form_data.group_ids,
)
memberships = self._create_membership_models(
channel_id=new_channel.id,
invited_by=user_id,
user_ids=users,
)
db.add_all(memberships)
db.add(new_channel)
db.commit()
return channel
@ -342,481 +93,22 @@ class ChannelTable:
channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels]
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
dialect_name = db.bind.dialect.name
# Public access
conditions = []
if group_ids or user_id:
conditions.extend(
[
Channel.access_control.is_(None),
cast(Channel.access_control, String) == "null",
]
)
# User-level permission
if user_id:
conditions.append(Channel.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
Channel.access_control[permission]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Channel.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query
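The filter built above grants access when the channel is public (access_control is NULL), owned by the requesting user, or readable/writable by one of the user's groups, with JSON containment expressed per dialect. For orientation, the access-control document it matches against is assumed to look roughly like this (the exact schema is defined elsewhere in Open WebUI and may carry more keys):

# Assumed shape of Channel.access_control matched by _has_permission.
access_control = {
    "read": {
        "group_ids": ["group-abc"],  # groups allowed to read the channel
        "user_ids": [],              # not consulted by _has_permission itself
    },
    "write": {
        "group_ids": [],
        "user_ids": [],
    },
}
# access_control is None      -> public channel, matched by the IS NULL / "null" branch
# Channel.user_id == user_id  -> the owner always passes
# group containment           -> access_control[permission]["group_ids"] contains a group id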
def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]:
with get_db() as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
]
membership_channels = (
db.query(Channel)
.join(ChannelMember, Channel.id == ChannelMember.channel_id)
.filter(
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
Channel.type.in_(["group", "dm"]),
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
.all()
)
query = db.query(Channel).filter(
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
or_(
Channel.type.is_(None), # True NULL/None
Channel.type == "", # Empty string
and_(Channel.type != "group", Channel.type != "dm"),
),
)
query = self._has_permission(
db, query, {"user_id": user_id, "group_ids": user_group_ids}
)
standard_channels = query.all()
all_channels = membership_channels + standard_channels
return [ChannelModel.model_validate(c) for c in all_channels]
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
with get_db() as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
match_count = func.sum(
case(
(ChannelMember.user_id.in_(unique_user_ids), 1),
else_=0,
)
)
subquery = (
db.query(ChannelMember.channel_id)
.group_by(ChannelMember.channel_id)
# 1. Channel must have exactly len(user_ids) members
.having(func.count(ChannelMember.user_id) == len(unique_user_ids))
# 2. All those members must be in unique_user_ids
.having(match_count == len(unique_user_ids))
.subquery()
)
channel = (
db.query(Channel)
.filter(
Channel.id.in_(subquery),
Channel.type == "dm",
)
.first()
)
return ChannelModel.model_validate(channel) if channel else None
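The GROUP BY / HAVING pair above selects a DM channel whose membership matches the requested users exactly: the member count equals the number of unique requested ids and every member is one of them. A pure-Python restatement of that predicate, for clarity only:

def is_exact_dm(channel_member_ids: set[str], requested_ids: list[str]) -> bool:
    # Mirrors the two HAVING clauses: same member count, and every member requested.
    requested = set(requested_ids)
    return channel_member_ids == requested

assert is_exact_dm({"a", "b"}, ["a", "b", "b"]) is True
assert is_exact_dm({"a", "b", "c"}, ["a", "b"]) is False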
def add_members_to_channel(
self,
channel_id: str,
invited_by: str,
user_ids: Optional[list[str]] = None,
group_ids: Optional[list[str]] = None,
) -> list[ChannelMemberModel]:
with get_db() as db:
# 1. Collect all user_ids including groups + inviter
requested_users = self._collect_unique_user_ids(
invited_by, user_ids, group_ids
)
existing_users = {
row.user_id
for row in db.query(ChannelMember.user_id)
.filter(ChannelMember.channel_id == channel_id)
.all()
}
new_user_ids = requested_users - existing_users
if not new_user_ids:
return [] # Nothing to add
new_memberships = self._create_membership_models(
channel_id, invited_by, new_user_ids
)
db.add_all(new_memberships)
db.commit()
return [
ChannelMemberModel.model_validate(membership)
for membership in new_memberships
]
def remove_members_from_channel(
self,
channel_id: str,
user_ids: list[str],
) -> int:
with get_db() as db:
result = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id.in_(user_ids),
)
.delete(synchronize_session=False)
)
db.commit()
return result # number of rows deleted
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
# Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember
channel = db.query(Channel).filter(Channel.id == channel_id).first()
if channel and channel.user_id == user_id:
return True
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
ChannelMember.role == "manager",
)
.first()
)
return membership is not None
def join_channel(
self, channel_id: str, user_id: str
) -> Optional[ChannelMemberModel]:
with get_db() as db:
# Check if the membership already exists
existing_membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if existing_membership:
return ChannelMemberModel.model_validate(existing_membership)
# Create new membership
channel_member = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"user_id": user_id,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"joined_at": int(time.time_ns()),
"left_at": None,
"last_read_at": int(time.time_ns()),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_membership = ChannelMember(**channel_member.model_dump())
db.add(new_membership)
db.commit()
return channel_member
def leave_channel(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.status = "left"
membership.is_active = False
membership.left_at = int(time.time_ns())
membership.updated_at = int(time.time_ns())
db.commit()
return True
def get_member_by_channel_and_user_id(
self, channel_id: str, user_id: str
) -> Optional[ChannelMemberModel]:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
with get_db() as db:
memberships = (
db.query(ChannelMember)
.filter(ChannelMember.channel_id == channel_id)
.all()
)
return [
ChannelMemberModel.model_validate(membership)
for membership in memberships
]
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.is_channel_pinned = is_pinned
membership.updated_at = int(time.time_ns())
db.commit()
return True
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.last_read_at = int(time.time_ns())
membership.updated_at = int(time.time_ns())
db.commit()
return True
def update_member_active_status(
self, channel_id: str, user_id: str, is_active: bool
) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
if not membership:
return False
membership.is_active = is_active
membership.updated_at = int(time.time_ns())
db.commit()
return True
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
)
.first()
)
return membership is not None
def get_channels_by_user_id(
self, user_id: str, permission: str = "read"
) -> list[ChannelModel]:
channels = self.get_channels()
return [
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
]
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
with get_db() as db:
channel_files = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_ids = [cf.channel_id for cf in channel_files]
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_file_id_and_user_id(
self, file_id: str, user_id: str
) -> list[ChannelModel]:
with get_db() as db:
# 1. Determine which channels have this file
channel_file_rows = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_ids = [row.channel_id for row in channel_file_rows]
if not channel_ids:
return []
# 2. Load all channel rows that still exist
channels = (
db.query(Channel)
.filter(
Channel.id.in_(channel_ids),
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
)
.all()
)
if not channels:
return []
# Preload user's group membership
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)]
allowed_channels = []
for channel in channels:
# --- Case A: group or dm => user must be an active member ---
if channel.type in ["group", "dm"]:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == channel.id,
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
.first()
)
if membership:
allowed_channels.append(ChannelModel.model_validate(channel))
continue
# --- Case B: standard channel => rely on ACL permissions ---
query = db.query(Channel).filter(Channel.id == channel.id)
query = self._has_permission(
db,
query,
{"user_id": user_id, "group_ids": user_group_ids},
permission="read",
)
allowed = query.first()
if allowed:
allowed_channels.append(ChannelModel.model_validate(allowed))
return allowed_channels
def get_channel_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
# Fetch the channel
channel: Channel = (
db.query(Channel)
.filter(
Channel.id == id,
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
)
.first()
)
if not channel:
return None
# If the channel is a group or dm, read access requires membership (active)
if channel.type in ["group", "dm"]:
membership = (
db.query(ChannelMember)
.filter(
ChannelMember.channel_id == id,
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
.first()
)
if membership:
return ChannelModel.model_validate(channel)
else:
return None
# For channels that are NOT group/dm, fall back to ACL-based read access
query = db.query(Channel).filter(Channel.id == id)
# Determine user groups
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
]
# Apply ACL rules
query = self._has_permission(
db,
query,
{"user_id": user_id, "group_ids": user_group_ids},
permission="read",
)
channel_allowed = query.first()
return (
ChannelModel.model_validate(channel_allowed)
if channel_allowed
else None
)
def update_channel_by_id(
self, id: str, form_data: ChannelForm
) -> Optional[ChannelModel]:
@ -826,77 +118,14 @@ class ChannelTable:
return None
channel.name = form_data.name
channel.description = form_data.description
channel.is_private = form_data.is_private
channel.data = form_data.data
channel.meta = form_data.meta
channel.access_control = form_data.access_control
channel.updated_at = int(time.time_ns())
db.commit()
return ChannelModel.model_validate(channel) if channel else None
def add_file_to_channel_by_id(
self, channel_id: str, file_id: str, user_id: str
) -> Optional[ChannelFileModel]:
with get_db() as db:
channel_file = ChannelFileModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"file_id": file_id,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = ChannelFile(**channel_file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChannelFileModel.model_validate(result)
else:
return None
except Exception:
return None
def set_file_message_id_in_channel_by_id(
self, channel_id: str, file_id: str, message_id: str
) -> bool:
try:
with get_db() as db:
channel_file = (
db.query(ChannelFile)
.filter_by(channel_id=channel_id, file_id=file_id)
.first()
)
if not channel_file:
return False
channel_file.message_id = message_id
channel_file.updated_at = int(time.time())
db.commit()
return True
except Exception:
return False
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool:
try:
with get_db() as db:
db.query(ChannelFile).filter_by(
channel_id=channel_id, file_id=file_id
).delete()
db.commit()
return True
except Exception:
return False
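Together with add_file_to_channel_by_id above, these methods cover the lifecycle of a channel attachment: link the file when it is uploaded, bind it to its message once the message id exists, and unlink it on removal. An illustrative flow, assuming the module-level Channels = ChannelTable() singleton that the rest of Open WebUI imports (its instantiation sits outside this hunk):

# Illustration only; ids are placeholders.
channel_file = Channels.add_file_to_channel_by_id(
    channel_id="channel-1",
    file_id="file-123",
    user_id="user-a",
)
if channel_file:
    Channels.set_file_message_id_in_channel_by_id(
        channel_id="channel-1",
        file_id="file-123",
        message_id="message-9",
    )
# Later, when the attachment is removed from the channel:
Channels.remove_file_from_channel_by_id("channel-1", "file-123")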
def delete_channel_by_id(self, id: str):
with get_db() as db:
db.query(Channel).filter(Channel.id == id).delete()

View file

@ -6,36 +6,25 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.folders import Folders
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
BigInteger,
Boolean,
Column,
ForeignKey,
String,
Text,
JSON,
Index,
UniqueConstraint,
)
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
from sqlalchemy.sql.expression import bindparam
####################
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
id = Column(String, primary_key=True, unique=True)
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(JSON)
@ -50,20 +39,6 @@ class Chat(Base):
meta = Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)
__table_args__ = (
# Performance indexes for common queries
# WHERE folder_id = ...
Index("folder_id_idx", "folder_id"),
# WHERE user_id = ... AND pinned = ...
Index("user_id_pinned_idx", "user_id", "pinned"),
# WHERE user_id = ... AND archived = ...
Index("user_id_archived_idx", "user_id", "archived"),
# WHERE user_id = ... ORDER BY updated_at DESC
Index("updated_at_user_id_idx", "updated_at", "user_id"),
# WHERE folder_id = ... AND user_id = ...
Index("folder_id_user_id_idx", "folder_id", "user_id"),
)
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
@ -84,38 +59,6 @@ class ChatModel(BaseModel):
folder_id: Optional[str] = None
class ChatFile(Base):
__tablename__ = "chat_file"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text, nullable=False)
chat_id = Column(Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False)
message_id = Column(Text, nullable=True)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint("chat_id", "file_id", name="uq_chat_file_chat_file"),
)
class ChatFileModel(BaseModel):
id: str
user_id: str
chat_id: str
message_id: Optional[str] = None
file_id: str
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
@ -123,18 +66,12 @@ class ChatFileModel(BaseModel):
class ChatForm(BaseModel):
chat: dict
folder_id: Optional[str] = None
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
created_at: Optional[int] = None
updated_at: Optional[int] = None
class ChatsImportForm(BaseModel):
chats: list[ChatImportForm]
folder_id: Optional[str] = None
class ChatTitleMessagesForm(BaseModel):
@ -167,77 +104,7 @@ class ChatTitleIdResponse(BaseModel):
created_at: int
class ChatListResponse(BaseModel):
items: list[ChatModel]
total: int
class ChatUsageStatsResponse(BaseModel):
id: str # chat id
models: dict = {} # models used in the chat with their usage counts
message_count: int # number of messages in the chat
history_models: dict = {} # models used in the chat history with their usage counts
history_message_count: int # number of messages in the chat history
history_user_message_count: int # number of user messages in the chat history
history_assistant_message_count: (
int # number of assistant messages in the chat history
)
average_response_time: (
float # average response time of assistant messages in seconds
)
average_user_message_content_length: (
float # average length of user message contents
)
average_assistant_message_content_length: (
float # average length of assistant message contents
)
tags: list[str] = [] # tags associated with the chat
last_message_at: int # timestamp of the last message
updated_at: int
created_at: int
model_config = ConfigDict(extra="allow")
class ChatUsageStatsListResponse(BaseModel):
items: list[ChatUsageStatsResponse]
total: int
model_config = ConfigDict(extra="allow")
class ChatTable:
def _clean_null_bytes(self, obj):
"""Recursively remove null bytes from strings in dict/list structures."""
return sanitize_data_for_db(obj)
def _sanitize_chat_row(self, chat_item):
"""
Clean a Chat SQLAlchemy model's title + chat JSON,
and return True if anything changed.
"""
changed = False
# Clean title
if chat_item.title:
cleaned = self._clean_null_bytes(chat_item.title)
if cleaned != chat_item.title:
chat_item.title = cleaned
changed = True
# Clean JSON
if chat_item.chat:
cleaned = self._clean_null_bytes(chat_item.chat)
if cleaned != chat_item.chat:
chat_item.chat = cleaned
changed = True
return changed
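The sanitization hooks exist because PostgreSQL rejects NUL characters inside TEXT and JSON values, so chat titles and payloads are cleaned on write and lazily on read. A hedged sketch of what the helpers imported from open_webui.utils.misc plausibly do (the real ones may normalize more than the NUL byte):

def sanitize_text_for_db(text: str) -> str:
    # PostgreSQL will not store NUL (\x00 / \u0000) inside TEXT or JSON columns.
    return text.replace("\x00", "")

def sanitize_data_for_db(obj):
    # Recursively clean strings nested in dicts and lists (chat JSON payloads).
    if isinstance(obj, str):
        return sanitize_text_for_db(obj)
    if isinstance(obj, dict):
        return {key: sanitize_data_for_db(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [sanitize_data_for_db(item) for item in obj]
    return obj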
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
@ -245,76 +112,59 @@ class ChatTable:
**{
"id": id,
"user_id": user_id,
"title": self._clean_null_bytes(
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": self._clean_null_bytes(form_data.chat),
"chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def import_chat(
self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
chat_item = Chat(**chat.model_dump())
db.add(chat_item)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(chat_item)
return ChatModel.model_validate(chat_item) if chat_item else None
def _chat_import_form_to_chat_model(
self, user_id: str, form_data: ChatImportForm
) -> ChatModel:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": self._clean_null_bytes(
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
),
"chat": self._clean_null_bytes(form_data.chat),
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": (
form_data.created_at if form_data.created_at else int(time.time())
),
"updated_at": (
form_data.updated_at if form_data.updated_at else int(time.time())
),
}
)
return chat
def import_chats(
self, user_id: str, chat_import_forms: list[ChatImportForm]
) -> list[ChatModel]:
with get_db() as db:
chats = []
for form_data in chat_import_forms:
chat = self._chat_import_form_to_chat_model(user_id, form_data)
chats.append(Chat(**chat.model_dump()))
db.add_all(chats)
db.commit()
return [ChatModel.model_validate(chat) for chat in chats]
db.refresh(result)
return ChatModel.model_validate(result) if result else None
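import_chats batches exported conversations into a single add_all/commit, preserving pinned state, metadata, and the original timestamps when the export supplies them. Illustration only, with placeholder data:

forms = [
    ChatImportForm(
        chat={"title": "Roadmap discussion", "history": {"messages": {}}},
        meta={"tags": ["imported"]},
        pinned=True,
        created_at=1700000000,
        updated_at=1700000500,
    ),
    ChatImportForm(chat={}),  # untitled: falls back to "New Chat" and time.time() stamps
]
imported = Chats.import_chats("user-a", forms)
# -> list[ChatModel], one per form, each with a fresh uuid4 id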
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
with get_db() as db:
chat_item = db.get(Chat, id)
chat_item.chat = self._clean_null_bytes(chat)
chat_item.title = (
self._clean_null_bytes(chat["title"])
if "title" in chat
else "New Chat"
)
chat_item.chat = chat
chat_item.title = chat["title"] if "title" in chat else "New Chat"
chat_item.updated_at = int(time.time())
db.commit()
db.refresh(chat_item)
@ -359,7 +209,7 @@ class ChatTable:
return chat.chat.get("title", "New Chat")
def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]:
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
@ -382,10 +232,6 @@ class ChatTable:
if chat is None:
return None
# Sanitize message content for null characters before upserting
if isinstance(message.get("content"), str):
message["content"] = sanitize_text_for_db(message["content"])
chat = chat.chat
history = chat.get("history", {})
@ -420,27 +266,6 @@ class ChatTable:
chat["history"] = history
return self.update_chat_by_id(id, chat)
def add_message_files_by_id_and_message_id(
self, id: str, message_id: str, files: list[dict]
) -> list[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
message_files = []
if message_id in history.get("messages", {}):
message_files = history["messages"][message_id].get("files", [])
message_files = message_files + files
history["messages"][message_id]["files"] = message_files
chat["history"] = history
self.update_chat_by_id(id, chat)
return message_files
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share
@ -455,9 +280,6 @@ class ChatTable:
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"meta": chat.meta,
"pinned": chat.pinned,
"folder_id": chat.folder_id,
"created_at": chat.created_at,
"updated_at": int(time.time()),
}
@ -489,9 +311,7 @@ class ChatTable:
shared_chat.title = chat.title
shared_chat.chat = chat.chat
shared_chat.meta = chat.meta
shared_chat.pinned = chat.pinned
shared_chat.folder_id = chat.folder_id
shared_chat.updated_at = int(time.time())
db.commit()
db.refresh(shared_chat)
@ -510,15 +330,6 @@ class ChatTable:
except Exception:
return False
def unarchive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": False})
db.commit()
return True
except Exception:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
@ -549,7 +360,6 @@ class ChatTable:
with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived
chat.folder_id = None
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
@ -585,10 +395,7 @@ class ChatTable:
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction:
if not getattr(Chat, order_by, None):
raise ValueError("Invalid order_by field")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
@ -649,19 +456,12 @@ class ChatTable:
self,
user_id: str,
include_archived: bool = False,
include_folders: bool = False,
include_pinned: bool = False,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
if not include_folders:
query = query.filter_by(folder_id=None)
if not include_pinned:
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived:
query = query.filter_by(archived=False)
@ -706,15 +506,8 @@ class ChatTable:
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat_item = db.get(Chat, id)
if chat_item is None:
return None
if self._sanitize_chat_row(chat_item):
db.commit()
db.refresh(chat_item)
return ChatModel.model_validate(chat_item)
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except Exception:
return None
@ -749,31 +542,14 @@ class ChatTable:
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(
self, user_id: str, skip: Optional[int] = None, limit: Optional[int] = None
) -> ChatListResponse:
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
query = (
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
)
total = query.count()
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
all_chats = query.all()
return ChatListResponse(
**{
"items": [ChatModel.model_validate(chat) for chat in all_chats],
"total": total,
}
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
@ -804,7 +580,7 @@ class ChatTable:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = sanitize_text_for_db(search_text).lower().strip()
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(
@ -820,45 +596,8 @@ class ChatTable:
if word.startswith("tag:")
]
# Extract folder names - handle spaces and case insensitivity
folders = Folders.search_folders_by_names(
user_id,
[
word.replace("folder:", "")
for word in search_text_words
if word.startswith("folder:")
],
)
folder_ids = [folder.id for folder in folders]
is_pinned = None
if "pinned:true" in search_text_words:
is_pinned = True
elif "pinned:false" in search_text_words:
is_pinned = False
is_archived = None
if "archived:true" in search_text_words:
is_archived = True
elif "archived:false" in search_text_words:
is_archived = False
is_shared = None
if "shared:true" in search_text_words:
is_shared = True
elif "shared:false" in search_text_words:
is_shared = False
search_text_words = [
word
for word in search_text_words
if (
not word.startswith("tag:")
and not word.startswith("folder:")
and not word.startswith("pinned:")
and not word.startswith("archived:")
and not word.startswith("shared:")
)
word for word in search_text_words if not word.startswith("tag:")
]
search_text = " ".join(search_text_words)
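On main, the search string can carry structured tokens (tag:, folder:, pinned:, archived:, shared:) that are stripped before the remaining words are matched against titles and message content. A pure-Python restatement of the grammar parsed above, for illustration only:

def parse_chat_search(search_text: str) -> dict:
    words = search_text.lower().strip().split()

    def flag(prefix: str):
        # "pinned:true" -> True, "pinned:false" -> False, absent -> None
        if f"{prefix}:true" in words:
            return True
        if f"{prefix}:false" in words:
            return False
        return None

    return {
        "tags": [w[len("tag:"):] for w in words if w.startswith("tag:")],
        "folders": [w[len("folder:"):] for w in words if w.startswith("folder:")],
        "pinned": flag("pinned"),
        "archived": flag("archived"),
        "shared": flag("shared"),
        "text": " ".join(
            w
            for w in words
            if not w.startswith(("tag:", "folder:", "pinned:", "archived:", "shared:"))
        ),
    }

parse_chat_search("tag:work folder:projects pinned:true roadmap")
# {'tags': ['work'], 'folders': ['projects'], 'pinned': True,
#  'archived': None, 'shared': None, 'text': 'roadmap'}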
@ -866,41 +605,30 @@ class ChatTable:
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
if is_archived is not None:
query = query.filter(Chat.archived == is_archived)
elif not include_archived:
if not include_archived:
query = query.filter(Chat.archived == False)
if is_pinned is not None:
query = query.filter(Chat.pinned == is_pinned)
if is_shared is not None:
if is_shared:
query = query.filter(Chat.share_id.isnot(None))
else:
query = query.filter(Chat.share_id.is_(None))
if folder_ids:
query = query.filter(Chat.folder_id.in_(folder_ids))
query = query.order_by(Chat.updated_at.desc())
# Check if the database dialect is either 'sqlite' or 'postgresql'
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
sqlite_content_sql = (
"EXISTS ("
" SELECT 1 "
" FROM json_each(Chat.chat, '$.messages') AS message "
" WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
")"
)
sqlite_content_clause = text(sqlite_content_sql)
query = query.filter(
or_(
Chat.title.ilike(bindparam("title_key")), sqlite_content_clause
).params(title_key=f"%{search_text}%", content_key=search_text)
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# If any tags were specified, the chat must contain all of them

@ -934,32 +662,23 @@ class ChatTable:
)
elif dialect_name == "postgresql":
# PostgreSQL doesn't allow null bytes in text. We filter those out by checking
# the JSON representation for \u0000 before attempting text extraction
# Safety filter: JSON field must not contain \u0000
query = query.filter(text("Chat.chat::text NOT LIKE '%\\\\u0000%'"))
# Safety filter: title must not contain actual null bytes
query = query.filter(text("Chat.title::text NOT LIKE '%\\x00%'"))
postgres_content_sql = """
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE json_typeof(message->'content') = 'string'
AND LOWER(message->>'content') LIKE '%' || :content_key || '%'
)
"""
postgres_content_clause = text(postgres_content_sql)
# PostgreSQL relies on proper JSON query for search
query = query.filter(
or_(
Chat.title.ilike(bindparam("title_key")),
postgres_content_clause,
)
).params(title_key=f"%{search_text}%", content_key=search_text.lower())
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# If any tags were specified, the chat must contain all of them
if "none" in tag_ids:
@ -1004,7 +723,7 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60
self, folder_id: str, user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
@ -1013,11 +732,6 @@ class ChatTable:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
@ -1147,16 +861,6 @@ class ChatTable:
return count
def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
query = query.filter_by(folder_id=folder_id)
count = query.count()
log.info(f"Count of chats for folder '{folder_id}': {count}")
return count
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
@ -1234,20 +938,6 @@ class ChatTable:
except Exception:
return False
def move_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str, new_folder_id: Optional[str]
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
{"folder_id": new_folder_id}
)
db.commit()
return True
except Exception:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
@ -1261,93 +951,5 @@ class ChatTable:
except Exception:
return False
def insert_chat_files(
self, chat_id: str, message_id: str, file_ids: list[str], user_id: str
) -> Optional[list[ChatFileModel]]:
if not file_ids:
return None
chat_message_file_ids = [
item.id
for item in self.get_chat_files_by_chat_id_and_message_id(
chat_id, message_id
)
]
# Remove duplicates and existing file_ids
file_ids = list(
set(
[
file_id
for file_id in file_ids
if file_id and file_id not in chat_message_file_ids
]
)
)
if not file_ids:
return None
try:
with get_db() as db:
now = int(time.time())
chat_files = [
ChatFileModel(
id=str(uuid.uuid4()),
user_id=user_id,
chat_id=chat_id,
message_id=message_id,
file_id=file_id,
created_at=now,
updated_at=now,
)
for file_id in file_ids
]
results = [
ChatFile(**chat_file.model_dump()) for chat_file in chat_files
]
db.add_all(results)
db.commit()
return chat_files
except Exception:
return None
def get_chat_files_by_chat_id_and_message_id(
self, chat_id: str, message_id: str
) -> list[ChatFileModel]:
with get_db() as db:
all_chat_files = (
db.query(ChatFile)
.filter_by(chat_id=chat_id, message_id=message_id)
.order_by(ChatFile.created_at.asc())
.all()
)
return [
ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files
]
def delete_chat_file(self, chat_id: str, file_id: str) -> bool:
try:
with get_db() as db:
db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete()
db.commit()
return True
except Exception:
return False
def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]:
with get_db() as db:
# Join Chat and ChatFile tables to get shared chats associated with the file_id
all_chats = (
db.query(Chat)
.join(ChatFile, Chat.id == ChatFile.chat_id)
.filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None))
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
Chats = ChatTable()
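The ChatFile helpers present only on main (insert_chat_files, get_chat_files_by_chat_id_and_message_id, delete_chat_file, get_shared_chats_by_file_id) track which uploads each message references. Illustrative usage with placeholder ids:

Chats.insert_chat_files(
    chat_id="chat-1",
    message_id="message-9",
    file_ids=["file-123", "file-123", "file-456"],  # duplicates and known links are skipped
    user_id="user-a",
)

attached = Chats.get_chat_files_by_chat_id_and_message_id("chat-1", "message-9")
# -> ChatFileModel rows for file-123 and file-456

Chats.delete_chat_file("chat-1", "file-456")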

View file

@ -4,12 +4,14 @@ import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.users import User
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
@ -19,7 +21,7 @@ log = logging.getLogger(__name__)
class Feedback(Base):
__tablename__ = "feedback"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
user_id = Column(Text)
version = Column(BigInteger, default=0)
type = Column(Text)
@ -60,13 +62,6 @@ class FeedbackResponse(BaseModel):
updated_at: int
class FeedbackIdResponse(BaseModel):
id: str
user_id: str
created_at: int
updated_at: int
class RatingData(BaseModel):
rating: Optional[str | int] = None
model_id: Optional[str] = None
@ -97,28 +92,6 @@ class FeedbackForm(BaseModel):
model_config = ConfigDict(extra="allow")
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str = "pending"
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserResponse] = None
class FeedbackListResponse(BaseModel):
items: list[FeedbackUserResponse]
total: int
class FeedbackTable:
def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm
@ -170,70 +143,6 @@ class FeedbackTable:
except Exception:
return None
def get_feedback_items(
self, filter: dict = {}, skip: int = 0, limit: int = 30
) -> FeedbackListResponse:
with get_db() as db:
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
if filter:
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "username":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "model_id":
# it's stored in feedback.data['model_id']
if direction == "asc":
query = query.order_by(
Feedback.data["model_id"].as_string().asc()
)
else:
query = query.order_by(
Feedback.data["model_id"].as_string().desc()
)
elif order_by == "rating":
# it's stored in feedback.data['rating']
if direction == "asc":
query = query.order_by(
Feedback.data["rating"].as_string().asc()
)
else:
query = query.order_by(
Feedback.data["rating"].as_string().desc()
)
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Feedback.updated_at.asc())
else:
query = query.order_by(Feedback.updated_at.desc())
else:
query = query.order_by(Feedback.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
feedbacks = []
for feedback, user in items:
feedback_model = FeedbackModel.model_validate(feedback)
user_model = UserResponse.model_validate(user)
feedbacks.append(
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
)
return FeedbackListResponse(items=feedbacks, total=total)
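get_feedback_items joins each feedback row to its author and supports ordering by username, model_id, rating, or updated_at, counting the total before pagination is applied. A hedged usage sketch, assuming the module-level Feedbacks = FeedbackTable() singleton instantiated outside this hunk:

page = Feedbacks.get_feedback_items(
    filter={"order_by": "rating", "direction": "desc"},
    skip=30,   # second page of 30
    limit=30,
)
page.total                                      # row count before skip/limit
[item.user.name for item in page.items if item.user]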
def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db:
return [

View file

@ -3,10 +3,12 @@ import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Files DB Schema
@ -15,7 +17,7 @@ log = logging.getLogger(__name__)
class File(Base):
__tablename__ = "file"
id = Column(String, primary_key=True, unique=True)
id = Column(String, primary_key=True)
user_id = Column(String)
hash = Column(Text, nullable=True)
@ -80,8 +82,7 @@ class FileModelResponse(BaseModel):
class FileMetadataResponse(BaseModel):
id: str
hash: Optional[str] = None
meta: Optional[dict] = None
meta: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@ -96,17 +97,6 @@ class FileForm(BaseModel):
access_control: Optional[dict] = None
class FileUpdateForm(BaseModel):
hash: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
class FileListResponse(BaseModel):
items: list[FileModel]
total: int
class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:
@ -140,24 +130,12 @@ class FilesTable:
except Exception:
return None
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id, user_id=user_id).first()
if file:
return FileModel.model_validate(file)
else:
return None
except Exception:
return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
try:
file = db.get(File, id)
return FileMetadataResponse(
id=file.id,
hash=file.hash,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
@ -169,15 +147,6 @@ class FilesTable:
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
file = self.get_file_by_id(id)
if not file:
return False
if file.user_id == user_id:
return True
# Implement additional access control logic here as needed
return False
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
with get_db() as db:
return [
@ -193,14 +162,11 @@ class FilesTable:
return [
FileMetadataResponse(
id=file.id,
hash=file.hash,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
for file in db.query(
File.id, File.hash, File.meta, File.created_at, File.updated_at
)
for file in db.query(File)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
@ -213,35 +179,11 @@ class FilesTable:
for file in db.query(File).filter_by(user_id=user_id).all()
]
def update_file_by_id(
self, id: str, form_data: FileUpdateForm
) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
if form_data.hash is not None:
file.hash = form_data.hash
if form_data.data is not None:
file.data = {**(file.data if file.data else {}), **form_data.data}
if form_data.meta is not None:
file.meta = {**(file.meta if file.meta else {}), **form_data.meta}
file.updated_at = int(time.time())
db.commit()
return FileModel.model_validate(file)
except Exception as e:
log.exception(f"Error updating file completely by id: {e}")
return None
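update_file_by_id applies a partial update: hash is replaced only when provided, while data and meta are shallow-merged over the existing JSON so untouched keys survive. Illustration only, with a hypothetical collection_name key and the module-level Files = FilesTable() singleton assumed from outside this hunk:

Files.update_file_by_id(
    "file-123",
    FileUpdateForm(meta={"collection_name": "kb-1"}),  # merged into the existing meta dict
)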
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
file.hash = hash
file.updated_at = int(time.time())
db.commit()
return FileModel.model_validate(file)
@ -253,7 +195,6 @@ class FilesTable:
try:
file = db.query(File).filter_by(id=id).first()
file.data = {**(file.data if file.data else {}), **data}
file.updated_at = int(time.time())
db.commit()
return FileModel.model_validate(file)
except Exception as e:
@ -265,7 +206,6 @@ class FilesTable:
try:
file = db.query(File).filter_by(id=id).first()
file.meta = {**(file.meta if file.meta else {}), **meta}
file.updated_at = int(time.time())
db.commit()
return FileModel.model_validate(file)
except Exception:

View file

@ -2,16 +2,18 @@ import logging
import time
import uuid
from typing import Optional
import re
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import get_permissions
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
@ -21,13 +23,12 @@ log = logging.getLogger(__name__)
class Folder(Base):
__tablename__ = "folder"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
parent_id = Column(Text, nullable=True)
user_id = Column(Text)
name = Column(Text)
items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
data = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
@ -40,7 +41,6 @@ class FolderModel(BaseModel):
name: str
items: Optional[dict] = None
meta: Optional[dict] = None
data: Optional[dict] = None
is_expanded: bool = False
created_at: int
updated_at: int
@ -48,20 +48,6 @@ class FolderModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
class FolderMetadataResponse(BaseModel):
icon: Optional[str] = None
class FolderNameIdResponse(BaseModel):
id: str
name: str
meta: Optional[FolderMetadataResponse] = None
parent_id: Optional[str] = None
is_expanded: bool = False
created_at: int
updated_at: int
####################
# Forms
####################
@ -69,21 +55,12 @@ class FolderNameIdResponse(BaseModel):
class FolderForm(BaseModel):
name: str
data: Optional[dict] = None
meta: Optional[dict] = None
model_config = ConfigDict(extra="allow")
class FolderUpdateForm(BaseModel):
name: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
model_config = ConfigDict(extra="allow")
class FolderTable:
def insert_new_folder(
self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None
self, user_id: str, name: str, parent_id: Optional[str] = None
) -> Optional[FolderModel]:
with get_db() as db:
id = str(uuid.uuid4())
@ -91,7 +68,7 @@ class FolderTable:
**{
"id": id,
"user_id": user_id,
**(form_data.model_dump(exclude_unset=True) or {}),
"name": name,
"parent_id": parent_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
@ -126,7 +103,7 @@ class FolderTable:
def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[list[FolderModel]]:
) -> Optional[FolderModel]:
try:
with get_db() as db:
folders = []
@ -210,8 +187,8 @@ class FolderTable:
log.error(f"update_folder: {e}")
return
def update_folder_by_id_and_user_id(
self, id: str, user_id: str, form_data: FolderUpdateForm
def update_folder_name_by_id_and_user_id(
self, id: str, user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
@ -220,35 +197,18 @@ class FolderTable:
if not folder:
return None
form_data = form_data.model_dump(exclude_unset=True)
existing_folder = (
db.query(Folder)
.filter_by(
name=form_data.get("name"),
parent_id=folder.parent_id,
user_id=user_id,
)
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
.first()
)
if existing_folder and existing_folder.id != id:
if existing_folder:
return None
folder.name = form_data.get("name", folder.name)
if "data" in form_data:
folder.data = {
**(folder.data or {}),
**form_data["data"],
}
if "meta" in form_data:
folder.meta = {
**(folder.meta or {}),
**form_data["meta"],
}
folder.name = name
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
@ -276,15 +236,18 @@ class FolderTable:
log.error(f"update_folder: {e}")
return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]:
def delete_folder_by_id_and_user_id(
self, id: str, user_id: str, delete_chats=True
) -> bool:
try:
folder_ids = []
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return folder_ids
return False
folder_ids.append(folder.id)
if delete_chats:
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
# Delete all children folders
def delete_children(folder):
@ -292,9 +255,12 @@ class FolderTable:
folder.id, user_id
)
for folder_child in folder_children:
if delete_chats:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
delete_children(folder_child)
folder_ids.append(folder_child.id)
folder = db.query(Folder).filter_by(id=folder_child.id).first()
db.delete(folder)
@ -303,62 +269,10 @@ class FolderTable:
delete_children(folder)
db.delete(folder)
db.commit()
return folder_ids
return True
except Exception as e:
log.error(f"delete_folder: {e}")
return []
def normalize_folder_name(self, name: str) -> str:
# Replace _ and space with a single space, lower case, collapse multiple spaces
name = re.sub(r"[\s_]+", " ", name)
return name.strip().lower()
def search_folders_by_names(
self, user_id: str, queries: list[str]
) -> list[FolderModel]:
"""
Search a user's folders for names matching any of the queries, treating '_' and whitespace as equivalent and ignoring case.
"""
normalized_queries = [self.normalize_folder_name(q) for q in queries]
if not normalized_queries:
return []
results = {}
with get_db() as db:
folders = db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders:
if self.normalize_folder_name(folder.name) in normalized_queries:
results[folder.id] = FolderModel.model_validate(folder)
# get children folders
children = self.get_children_folders_by_id_and_user_id(
folder.id, user_id
)
for child in children:
results[child.id] = child
# Return the results as a list
if not results:
return []
else:
results = list(results.values())
return results
def search_folders_by_name_contains(
self, user_id: str, query: str
) -> list[FolderModel]:
"""
Partial match: normalized name contains (as substring) the normalized query.
"""
normalized_query = self.normalize_folder_name(query)
results = []
with get_db() as db:
folders = db.query(Folder).filter_by(user_id=user_id).all()
for folder in folders:
norm_name = self.normalize_folder_name(folder.name)
if normalized_query in norm_name:
results.append(FolderModel.model_validate(folder))
return results
return False
Folders = FolderTable()
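Folder search normalizes names by collapsing runs of underscores and whitespace into single spaces and lowercasing, so "Project_Notes" and "  project   notes " resolve to the same key; exact-name matches also pull in each matched folder's children. Illustration only:

Folders.normalize_folder_name("Project_Notes")       # -> "project notes"
Folders.normalize_folder_name("  project   notes ")  # -> "project notes"

Folders.search_folders_by_names("user-a", ["Project_Notes"])
# -> the matching folder(s) plus their children, deduplicated by id

Folders.search_folders_by_name_contains("user-a", "note")
# -> every folder whose normalized name contains "note"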

View file

@ -3,11 +3,13 @@ import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserModel
from open_webui.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
from sqlalchemy import BigInteger, Boolean, Column, String, Text
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Functions DB Schema
@ -17,7 +19,7 @@ log = logging.getLogger(__name__)
class Function(Base):
__tablename__ = "function"
id = Column(String, primary_key=True, unique=True)
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
type = Column(Text)
@ -29,13 +31,10 @@ class Function(Base):
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
__table_args__ = (Index("is_global_idx", "is_global"),)
class FunctionMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
class FunctionModel(BaseModel):
@ -53,31 +52,11 @@ class FunctionModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
class FunctionWithValvesModel(BaseModel):
id: str
user_id: str
name: str
type: str
content: str
meta: FunctionMeta
valves: Optional[dict] = None
is_active: bool = False
is_global: bool = False
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FunctionUserResponse(FunctionModel):
user: Optional[UserModel] = None
class FunctionResponse(BaseModel):
id: str
user_id: str
@ -130,8 +109,8 @@ class FunctionsTable:
return None
def sync_functions(
self, user_id: str, functions: list[FunctionWithValvesModel]
) -> list[FunctionWithValvesModel]:
self, user_id: str, functions: list[FunctionModel]
) -> list[FunctionModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try:
with get_db() as db:
@ -185,48 +164,19 @@ class FunctionsTable:
except Exception:
return None
def get_functions(
self, active_only=False, include_valves=False
) -> list[FunctionModel | FunctionWithValvesModel]:
def get_functions(self, active_only=False) -> list[FunctionModel]:
with get_db() as db:
if active_only:
functions = db.query(Function).filter_by(is_active=True).all()
else:
functions = db.query(Function).all()
if include_valves:
return [
FunctionWithValvesModel.model_validate(function)
for function in functions
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(is_active=True).all()
]
else:
return [
FunctionModel.model_validate(function) for function in functions
FunctionModel.model_validate(function)
for function in db.query(Function).all()
]
def get_function_list(self) -> list[FunctionUserResponse]:
with get_db() as db:
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
user_ids = list(set(func.user_id for func in functions))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
return [
FunctionUserResponse.model_validate(
{
**FunctionModel.model_validate(func).model_dump(),
"user": (
users_dict.get(func.user_id).model_dump()
if func.user_id in users_dict
else None
),
}
)
for func in functions
]
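On main, get_functions gains an include_valves flag that returns FunctionWithValvesModel rows with valve settings attached, and get_function_list batches a single Users lookup to decorate each function with its owner. A hedged sketch, assuming the module-level Functions = FunctionsTable() singleton defined outside this hunk:

active = Functions.get_functions(active_only=True)           # list[FunctionModel]
with_valves = Functions.get_functions(include_valves=True)   # list[FunctionWithValvesModel]

for entry in Functions.get_function_list():                  # FunctionUserResponse rows
    owner = entry.user.name if entry.user else "unknown"
    print(entry.id, owner)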
def get_functions_by_type(
self, type: str, active_only=False
) -> list[FunctionModel]:
@ -285,29 +235,6 @@ class FunctionsTable:
except Exception:
return None
def update_function_metadata_by_id(
self, id: str, metadata: dict
) -> Optional[FunctionModel]:
with get_db() as db:
try:
function = db.get(Function, id)
if function:
if function.meta:
function.meta = {**function.meta, **metadata}
else:
function.meta = metadata
function.updated_at = int(time.time())
db.commit()
db.refresh(function)
return self.get_function_by_id(id)
else:
return None
except Exception as e:
log.exception(f"Error updating function metadata by id {id}: {e}")
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
@ -323,7 +250,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
log.exception(f"Error getting user values by id {id} and user id {user_id}")
log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(

View file

@ -5,26 +5,17 @@ from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
BigInteger,
Column,
String,
Text,
JSON,
and_,
func,
ForeignKey,
cast,
or_,
)
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# UserGroup DB Schema
@ -44,12 +35,14 @@ class Group(Base):
meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
@ -60,62 +53,38 @@ class GroupModel(BaseModel):
meta: Optional[dict] = None
permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class GroupMember(Base):
__tablename__ = "group_member"
id = Column(Text, unique=True, primary_key=True)
group_id = Column(
Text,
ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=True)
updated_at = Column(BigInteger, nullable=True)
class GroupMemberModel(BaseModel):
id: str
group_id: str
user_id: str
created_at: Optional[int] = None # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch
####################
# Forms
####################
class GroupResponse(GroupModel):
member_count: Optional[int] = None
class GroupResponse(BaseModel):
id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel):
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
class UserIdsForm(BaseModel):
user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm):
pass
class GroupListResponse(BaseModel):
items: list[GroupResponse] = []
total: int = 0
user_ids: Optional[list[str]] = None
class GroupTable:
@ -146,94 +115,24 @@ class GroupTable:
except Exception:
return None
def get_all_groups(self) -> list[GroupModel]:
def get_groups(self) -> list[GroupModel]:
with get_db() as db:
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
return [GroupModel.model_validate(group) for group in groups]
def get_groups(self, filter) -> list[GroupResponse]:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
share_value = filter["share"]
json_share = Group.data["config"]["share"].as_boolean()
if share_value:
query = query.filter(
or_(
Group.data.is_(None),
json_share.is_(None),
json_share == True,
)
)
else:
query = query.filter(
and_(Group.data.isnot(None), json_share == False)
)
groups = query.order_by(Group.updated_at.desc()).all()
return [
GroupResponse.model_validate(
{
**GroupModel.model_validate(group).model_dump(),
"member_count": self.get_group_member_count_by_id(group.id),
}
)
for group in groups
GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
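Illustrative sketch (not part of either branch): the `share` filter above reads a nested JSON boolean portably through SQLAlchemy's `.as_boolean()` accessor. A minimal, self-contained example of the same predicate, assuming SQLAlchemy 2.x and an in-memory SQLite database (table and engine names here are invented for the sketch):

from sqlalchemy import JSON, Column, Text, create_engine, or_, select
from sqlalchemy.orm import DeclarativeBase, Session

class _Base(DeclarativeBase):
    pass

class _Group(_Base):
    __tablename__ = "group_sketch"  # hypothetical table, used only for this sketch
    id = Column(Text, primary_key=True)
    data = Column(JSON, nullable=True)

engine = create_engine("sqlite://")
_Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        _Group(id="a", data=None),                          # no data at all -> treated as shared
        _Group(id="b", data={"config": {"share": True}}),   # explicitly shared
        _Group(id="c", data={"config": {"share": False}}),  # explicitly private
    ])
    session.commit()

    json_share = _Group.data["config"]["share"].as_boolean()
    shared_ids = session.scalars(
        select(_Group.id).where(
            or_(_Group.data.is_(None), json_share.is_(None), json_share == True)
        )
    ).all()
    print(sorted(shared_ids))  # ['a', 'b']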
def search_groups(
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
) -> GroupListResponse:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
# 'share' is stored in data JSON, support both sqlite and postgres
share_value = filter["share"]
print("Filtering by share:", share_value)
query = query.filter(
Group.data.op("->>")("share") == str(share_value)
)
total = query.count()
query = query.order_by(Group.updated_at.desc())
groups = query.offset(skip).limit(limit).all()
return {
"items": [
GroupResponse.model_validate(
{
**GroupModel.model_validate(group).model_dump(),
"member_count": self.get_group_member_count_by_id(group.id),
}
)
for group in groups
],
"total": total,
}
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.filter(
func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.order_by(Group.updated_at.desc())
.all()
]
@ -246,63 +145,12 @@ class GroupTable:
except Exception:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
with get_db() as db:
members = (
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
)
if not members:
return None
return [m[0] for m in members]
def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]:
with get_db() as db:
members = (
db.query(GroupMember.group_id, GroupMember.user_id)
.filter(GroupMember.group_id.in_(group_ids))
.all()
)
group_user_ids: dict[str, list[str]] = {
group_id: [] for group_id in group_ids
}
for group_id, user_id in members:
group_user_ids[group_id].append(user_id)
return group_user_ids
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
with get_db() as db:
# Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
# Insert new members
now = int(time.time())
new_members = [
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
)
for user_id in user_ids
]
db.add_all(new_members)
db.commit()
def get_group_member_count_by_id(self, id: str) -> int:
with get_db() as db:
count = (
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
return count if count else 0
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
@ -343,189 +191,59 @@ class GroupTable:
def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db:
try:
# Find all groups the user belongs to
groups = (
db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.all()
)
groups = self.get_groups_by_member_id(user_id)
# Remove the user from each group
for group in groups:
db.query(GroupMember).filter(
GroupMember.group_id == group.id, GroupMember.user_id == user_id
).delete()
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{"updated_at": int(time.time())}
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception:
db.rollback()
return False
def create_groups_by_group_names(
def sync_user_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> list[GroupModel]:
# check for existing groups
existing_groups = self.get_all_groups()
existing_group_names = {group.name for group in existing_groups}
new_groups = []
with get_db() as db:
for group_name in group_names:
if group_name not in existing_group_names:
new_group = GroupModel(
id=str(uuid.uuid4()),
user_id=user_id,
name=group_name,
description="",
created_at=int(time.time()),
updated_at=int(time.time()),
)
try:
result = Group(**new_group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
new_groups.append(GroupModel.model_validate(result))
except Exception as e:
log.exception(e)
continue
return new_groups
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
) -> bool:
with get_db() as db:
try:
now = int(time.time())
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
# 1. Groups that SHOULD contain the user
target_groups = (
db.query(Group).filter(Group.name.in_(group_names)).all()
)
target_group_ids = {g.id for g in target_groups}
# Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id)
# 2. Groups the user is CURRENTLY in
existing_group_ids = {
g.id
for g in db.query(Group)
.join(GroupMember, GroupMember.group_id == Group.id)
.filter(GroupMember.user_id == user_id)
.all()
}
# 3. Determine adds + removals
groups_to_add = target_group_ids - existing_group_ids
groups_to_remove = existing_group_ids - target_group_ids
# 4. Remove in one bulk delete
if groups_to_remove:
db.query(GroupMember).filter(
GroupMember.user_id == user_id,
GroupMember.group_id.in_(groups_to_remove),
).delete(synchronize_session=False)
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
{"updated_at": now}, synchronize_session=False
)
# 5. Bulk insert missing memberships
for group_id in groups_to_add:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
for group in existing_groups:
if group.id not in group_ids:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
)
if groups_to_add:
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
{"updated_at": now}, synchronize_session=False
)
# Add user to new groups
for group in groups:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception as e:
log.exception(e)
db.rollback()
return False
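Illustrative sketch (not part of either branch): the numbered steps in the new sync_user_groups_by_group_names above reduce to two set differences followed by one bulk delete and a batch of inserts. The same plan in plain Python:

# Set-difference membership sync, shown without the database layer.
def plan_membership_sync(target_group_ids: set[str], existing_group_ids: set[str]):
    """Return (groups_to_add, groups_to_remove) for a single user."""
    groups_to_add = target_group_ids - existing_group_ids
    groups_to_remove = existing_group_ids - target_group_ids
    return groups_to_add, groups_to_remove

# Example: the user should end up in {"admins", "staff"} but currently sits in {"staff", "guests"}.
to_add, to_remove = plan_membership_sync({"admins", "staff"}, {"staff", "guests"})
assert to_add == {"admins"} and to_remove == {"guests"}
# The real method then issues one bulk DELETE over `to_remove`, one GroupMember INSERT
# per id in `to_add`, and touches each affected group's updated_at exactly once.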
def add_users_to_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
if not group:
return None
now = int(time.time())
for user_id in user_ids or []:
try:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=id,
user_id=user_id,
created_at=now,
updated_at=now,
)
)
db.flush() # Detect unique constraint violation early
except Exception:
db.rollback() # Clear failed INSERT
db.begin() # Start a new transaction
continue # Duplicate → ignore
group.updated_at = now
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
def remove_users_from_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
if not group:
return None
if not user_ids:
return GroupModel.model_validate(group)
# Remove each user from group_member
for user_id in user_ids:
db.query(GroupMember).filter(
GroupMember.group_id == id, GroupMember.user_id == user_id
).delete()
# Update group timestamp
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
Groups = GroupTable()

View file

@ -5,34 +5,19 @@ from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
File,
FileModel,
FileMetadataResponse,
FileModelResponse,
)
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse
from open_webui.models.files import FileMetadataResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
BigInteger,
Column,
ForeignKey,
String,
Text,
JSON,
UniqueConstraint,
or_,
)
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
from open_webui.utils.db.access_control import has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Knowledge DB Schema
@ -48,7 +33,9 @@ class Knowledge(Base):
name = Column(Text)
description = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
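Illustrative sketch (not part of either branch): the comment above is truncated by the hunk, but judging from how access_control is queried elsewhere in this diff (access_control[permission]["group_ids"]), a plausible payload and check look like the following. The exact key layout is an inference, not a documented schema:

# Assumed shape: {"read": {"group_ids": [...], "user_ids": [...]}, "write": {...}} or None for public.
example_access_control = {
    "read": {"group_ids": ["group-a"], "user_ids": ["user-1"]},
    "write": {"group_ids": [], "user_ids": ["user-1"]},
}

def _has_access(user_id: str, permission: str, access_control, user_group_ids: set[str]) -> bool:
    if access_control is None:
        # `None` is public per the comment above; whether that also grants write
        # is a policy choice -- this sketch only treats it as readable.
        return permission == "read"
    rule = access_control.get(permission, {})
    return user_id in rule.get("user_ids", []) or bool(
        user_group_ids & set(rule.get("group_ids", []))
    )

assert _has_access("user-2", "read", example_access_control, {"group-a"})
assert not _has_access("user-2", "write", example_access_control, set())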
@ -79,6 +66,7 @@ class KnowledgeModel(BaseModel):
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
@ -87,42 +75,11 @@ class KnowledgeModel(BaseModel):
updated_at: int # timestamp in epoch
class KnowledgeFile(Base):
__tablename__ = "knowledge_file"
id = Column(Text, unique=True, primary_key=True)
knowledge_id = Column(
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
),
)
class KnowledgeFileModel(BaseModel):
id: str
knowledge_id: str
file_id: str
user_id: str
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class KnowledgeUserModel(KnowledgeModel):
user: Optional[UserResponse] = None
@ -132,29 +89,16 @@ class KnowledgeResponse(KnowledgeModel):
class KnowledgeUserResponse(KnowledgeUserModel):
pass
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeForm(BaseModel):
name: str
description: str
data: Optional[dict] = None
access_control: Optional[dict] = None
class FileUserResponse(FileModelResponse):
user: Optional[UserResponse] = None
class KnowledgeListResponse(BaseModel):
items: list[KnowledgeUserModel]
total: int
class KnowledgeFileListResponse(BaseModel):
items: list[FileUserResponse]
total: int
class KnowledgeTable:
def insert_new_knowledge(
self, user_id: str, form_data: KnowledgeForm
@ -182,21 +126,13 @@ class KnowledgeTable:
except Exception:
return None
def get_knowledge_bases(
self, skip: int = 0, limit: int = 30
) -> list[KnowledgeUserModel]:
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
with get_db() as db:
all_knowledge = (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
)
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
knowledge_bases = []
for knowledge in all_knowledge:
user = users_dict.get(knowledge.user_id)
for knowledge in (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
):
user = Users.get_user_by_id(knowledge.user_id)
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
@ -207,147 +143,15 @@ class KnowledgeTable:
)
return knowledge_bases
def search_knowledge_bases(
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeListResponse:
try:
with get_db() as db:
query = db.query(Knowledge, User).outerjoin(
User, User.id == Knowledge.user_id
)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Knowledge.name.ilike(f"%{query_key}%"),
Knowledge.description.ilike(f"%{query_key}%"),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Knowledge.user_id == user_id)
elif view_option == "shared":
query = query.filter(Knowledge.user_id != user_id)
query = has_permission(db, Knowledge, query, filter)
query = query.order_by(Knowledge.updated_at.desc())
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
knowledge_bases = []
for knowledge_base, user in items:
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
**KnowledgeModel.model_validate(
knowledge_base
).model_dump(),
"user": (
UserModel.model_validate(user).model_dump()
if user
else None
),
}
)
)
return KnowledgeListResponse(items=knowledge_bases, total=total)
except Exception as e:
print(e)
return KnowledgeListResponse(items=[], total=0)
def search_knowledge_files(
self, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeFileListResponse:
"""
Scalable version: search files across all knowledge bases the user has
READ access to, without loading all KBs or using large IN() lists.
"""
try:
with get_db() as db:
# Base query: join Knowledge → KnowledgeFile → File
query = (
db.query(File, User)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.join(Knowledge, KnowledgeFile.knowledge_id == Knowledge.id)
.outerjoin(User, User.id == KnowledgeFile.user_id)
)
# Apply access-control directly to the joined query
# This makes the database handle filtering, even with 10k+ KBs
query = has_permission(db, Knowledge, query, filter)
# Apply filename search
if filter:
q = filter.get("query")
if q:
query = query.filter(File.filename.ilike(f"%{q}%"))
# Order by file changes
query = query.order_by(File.updated_at.desc())
# Count before pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
rows = query.all()
items = []
for file, user in rows:
items.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
)
)
return KnowledgeFileListResponse(items=items, total=total)
except Exception as e:
print("search_knowledge_files error:", e)
return KnowledgeFileListResponse(items=[], total=0)
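Illustrative sketch (not part of either branch): a hedged usage example for the scalable file search above. The filter keys are inferred from this diff rather than a documented contract, and the module-level `Knowledges = KnowledgeTable()` singleton is assumed:

def search_my_files(user_id: str, group_ids: list[str], text: str):
    # "query" is matched against File.filename with ILIKE; the access keys below are
    # assumed to be what has_permission() consumes.
    return Knowledges.search_knowledge_files(
        filter={
            "query": text,
            "user_id": user_id,      # assumed owner-level access key
            "group_ids": group_ids,  # assumed group-level access key
        },
        skip=0,
        limit=30,
    )

# page = search_my_files("user-1", ["group-a"], "report")
# page.items -> list[FileUserResponse]; page.total -> match count before pagination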
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
knowledge = self.get_knowledge_by_id(id)
if not knowledge:
return False
if knowledge.user_id == user_id:
return True
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return has_access(user_id, permission, knowledge.access_control, user_group_ids)
def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or has_access(
user_id, permission, knowledge_base.access_control, user_group_ids
)
or has_access(user_id, permission, knowledge_base.access_control)
]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
@ -358,197 +162,6 @@ class KnowledgeTable:
except Exception:
return None
def get_knowledge_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[KnowledgeModel]:
knowledge = self.get_knowledge_by_id(id)
if not knowledge:
return None
if knowledge.user_id == user_id:
return knowledge
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
if has_access(user_id, "write", knowledge.access_control, user_group_ids):
return knowledge
return None
def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]:
try:
with get_db() as db:
knowledges = (
db.query(Knowledge)
.join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
.filter(KnowledgeFile.file_id == file_id)
.all()
)
return [
KnowledgeModel.model_validate(knowledge) for knowledge in knowledges
]
except Exception:
return []
def search_files_by_id(
self,
knowledge_id: str,
user_id: str,
filter: dict,
skip: int = 0,
limit: int = 30,
) -> KnowledgeFileListResponse:
try:
with get_db() as db:
query = (
db.query(File, User)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.outerjoin(User, User.id == KnowledgeFile.user_id)
.filter(KnowledgeFile.knowledge_id == knowledge_id)
)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(or_(File.filename.ilike(f"%{query_key}%")))
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(KnowledgeFile.user_id == user_id)
elif view_option == "shared":
query = query.filter(KnowledgeFile.user_id != user_id)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(File.filename.asc())
else:
query = query.order_by(File.filename.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(File.created_at.asc())
else:
query = query.order_by(File.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(File.updated_at.asc())
else:
query = query.order_by(File.updated_at.desc())
else:
query = query.order_by(File.updated_at.desc())
else:
query = query.order_by(File.updated_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
files = []
for file, user in items:
files.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
)
)
return KnowledgeFileListResponse(items=files, total=total)
except Exception as e:
print(e)
return KnowledgeFileListResponse(items=[], total=0)
def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
try:
with get_db() as db:
files = (
db.query(File)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.filter(KnowledgeFile.knowledge_id == knowledge_id)
.all()
)
return [FileModel.model_validate(file) for file in files]
except Exception:
return []
def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]:
try:
with get_db() as db:
files = self.get_files_by_id(knowledge_id)
return [FileMetadataResponse(**file.model_dump()) for file in files]
except Exception:
return []
def add_file_to_knowledge_by_id(
self, knowledge_id: str, file_id: str, user_id: str
) -> Optional[KnowledgeFileModel]:
with get_db() as db:
knowledge_file = KnowledgeFileModel(
**{
"id": str(uuid.uuid4()),
"knowledge_id": knowledge_id,
"file_id": file_id,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = KnowledgeFile(**knowledge_file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return KnowledgeFileModel.model_validate(result)
else:
return None
except Exception:
return None
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool:
try:
with get_db() as db:
db.query(KnowledgeFile).filter_by(
knowledge_id=knowledge_id, file_id=file_id
).delete()
db.commit()
return True
except Exception:
return False
def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
# Delete all knowledge_file entries for this knowledge_id
db.query(KnowledgeFile).filter_by(knowledge_id=id).delete()
db.commit()
# Update the knowledge entry's updated_at timestamp
db.query(Knowledge).filter_by(id=id).update(
{
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def update_knowledge_by_id(
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]:

View file

@ -14,7 +14,7 @@ from sqlalchemy import BigInteger, Column, String, Text
class Memory(Base):
__tablename__ = "memory"
id = Column(String, primary_key=True, unique=True)
id = Column(String, primary_key=True)
user_id = Column(String)
content = Column(Text)
updated_at = Column(BigInteger)
@ -71,13 +71,9 @@ class MemoriesTable:
) -> Optional[MemoryModel]:
with get_db() as db:
try:
memory = db.get(Memory, id)
if not memory or memory.user_id != user_id:
return None
memory.content = content
memory.updated_at = int(time.time())
db.query(Memory).filter_by(id=id, user_id=user_id).update(
{"content": content, "updated_at": int(time.time())}
)
db.commit()
return self.get_memory_by_id(id)
except Exception:
@ -131,12 +127,7 @@ class MemoriesTable:
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
try:
memory = db.get(Memory, id)
if not memory or memory.user_id != user_id:
return None
# Delete the memory
db.delete(memory)
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True

View file

@ -5,11 +5,9 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.users import Users, User, UserNameResponse
from open_webui.models.channels import Channels, ChannelMember
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
@ -21,7 +19,7 @@ from sqlalchemy.sql import exists
class MessageReaction(Base):
__tablename__ = "message_reaction"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
user_id = Column(Text)
message_id = Column(Text)
name = Column(Text)
@ -40,19 +38,13 @@ class MessageReactionModel(BaseModel):
class Message(Base):
__tablename__ = "message"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
user_id = Column(Text)
channel_id = Column(Text, nullable=True)
reply_to_id = Column(Text, nullable=True)
parent_id = Column(Text, nullable=True)
# Pins
is_pinned = Column(Boolean, nullable=False, default=False)
pinned_at = Column(BigInteger, nullable=True)
pinned_by = Column(Text, nullable=True)
content = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
@ -68,20 +60,14 @@ class MessageModel(BaseModel):
user_id: str
channel_id: Optional[str] = None
reply_to_id: Optional[str] = None
parent_id: Optional[str] = None
# Pins
is_pinned: bool = False
pinned_by: Optional[str] = None
pinned_at: Optional[int] = None # timestamp in epoch (time_ns)
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int # timestamp in epoch (time_ns)
updated_at: int # timestamp in epoch (time_ns)
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
@ -90,9 +76,7 @@ class MessageModel(BaseModel):
class MessageForm(BaseModel):
temp_id: Optional[str] = None
content: str
reply_to_id: Optional[str] = None
parent_id: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
@ -100,36 +84,11 @@ class MessageForm(BaseModel):
class Reactions(BaseModel):
name: str
users: list[dict]
user_ids: list[str]
count: int
class MessageUserResponse(MessageModel):
user: Optional[UserNameResponse] = None
class MessageUserSlimResponse(MessageUserResponse):
data: bool | None = None
@field_validator("data", mode="before")
def convert_data_to_bool(cls, v):
# No data or not a dict → False
if not isinstance(v, dict):
return False
# True if ANY value in the dict is non-empty
return any(bool(val) for val in v.values())
class MessageReplyToResponse(MessageUserResponse):
reply_to_message: Optional[MessageUserSlimResponse] = None
class MessageWithReactionsResponse(MessageUserSlimResponse):
reactions: list[Reactions]
class MessageResponse(MessageReplyToResponse):
class MessageResponse(MessageModel):
latest_reply_at: Optional[int]
reply_count: int
reactions: list[Reactions]
@ -140,21 +99,15 @@ class MessageTable:
self, form_data: MessageForm, channel_id: str, user_id: str
) -> Optional[MessageModel]:
with get_db() as db:
channel_member = Channels.join_channel(channel_id, user_id)
id = str(uuid.uuid4())
ts = int(time.time_ns())
message = MessageModel(
**{
"id": id,
"user_id": user_id,
"channel_id": channel_id,
"reply_to_id": form_data.reply_to_id,
"parent_id": form_data.parent_id,
"is_pinned": False,
"pinned_at": None,
"pinned_by": None,
"content": form_data.content,
"data": form_data.data,
"meta": form_data.meta,
@ -162,8 +115,8 @@ class MessageTable:
"updated_at": ts,
}
)
result = Message(**message.model_dump())
db.add(result)
db.commit()
db.refresh(result)
@ -175,32 +128,19 @@ class MessageTable:
if not message:
return None
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
reactions = self.get_reactions_by_message_id(id)
thread_replies = self.get_thread_replies_by_message_id(id)
replies = self.get_replies_by_message_id(id)
user = Users.get_user_by_id(message.user_id)
return MessageResponse.model_validate(
{
return MessageResponse(
**{
**MessageModel.model_validate(message).model_dump(),
"user": user.model_dump() if user else None,
"reply_to_message": (
reply_to_message.model_dump() if reply_to_message else None
),
"latest_reply_at": (
thread_replies[0].created_at if thread_replies else None
),
"reply_count": len(thread_replies),
"latest_reply_at": replies[0].created_at if replies else None,
"reply_count": len(replies),
"reactions": reactions,
}
)
def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
@ -208,27 +148,7 @@ class MessageTable:
.order_by(Message.created_at.desc())
.all()
)
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
}
)
)
return messages
return [MessageModel.model_validate(message) for message in all_messages]
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
with get_db() as db:
@ -239,7 +159,7 @@ class MessageTable:
def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageReplyToResponse]:
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
@ -249,31 +169,11 @@ class MessageTable:
.limit(limit)
.all()
)
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
}
)
)
return messages
return [MessageModel.model_validate(message) for message in all_messages]
def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
) -> list[MessageReplyToResponse]:
) -> list[MessageModel]:
with get_db() as db:
message = db.get(Message, parent_id)
@ -293,49 +193,6 @@ class MessageTable:
if len(all_messages) < limit:
all_messages.append(message)
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
}
)
)
return messages
def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]:
with get_db() as db:
message = (
db.query(Message)
.filter_by(channel_id=channel_id)
.order_by(Message.created_at.desc())
.first()
)
return MessageModel.model_validate(message) if message else None
def get_pinned_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id, is_pinned=True)
.order_by(Message.pinned_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def update_message_by_id(
@ -344,57 +201,17 @@ class MessageTable:
with get_db() as db:
message = db.get(Message, id)
message.content = form_data.content
message.data = {
**(message.data if message.data else {}),
**(form_data.data if form_data.data else {}),
}
message.meta = {
**(message.meta if message.meta else {}),
**(form_data.meta if form_data.meta else {}),
}
message.data = form_data.data
message.meta = form_data.meta
message.updated_at = int(time.time_ns())
db.commit()
db.refresh(message)
return MessageModel.model_validate(message) if message else None
def update_is_pinned_by_id(
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None
) -> Optional[MessageModel]:
with get_db() as db:
message = db.get(Message, id)
message.is_pinned = is_pinned
message.pinned_at = int(time.time_ns()) if is_pinned else None
message.pinned_by = pinned_by if is_pinned else None
db.commit()
db.refresh(message)
return MessageModel.model_validate(message) if message else None
def get_unread_message_count(
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None
) -> int:
with get_db() as db:
query = db.query(Message).filter(
Message.channel_id == channel_id,
Message.parent_id == None, # only count top-level messages
Message.created_at > (last_read_at if last_read_at else 0),
)
if user_id:
query = query.filter(Message.user_id != user_id)
return query.count()
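Illustrative sketch (not part of either branch): one way the unread counter above could drive a per-channel badge. The `Messages` singleton and the source of last_read_at (e.g. the viewer's channel-membership record) are assumptions:

def unread_badge(channel_id: str, viewer_id: str, last_read_at: int | None) -> str:
    # Counts only top-level messages from other users, newer than last_read_at (time_ns).
    count = Messages.get_unread_message_count(channel_id, viewer_id, last_read_at)
    if count == 0:
        return ""
    return "99+" if count > 99 else str(count)

# unread_badge("channel-1", "user-1", last_read_at=1_700_000_000_000_000_000)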
def add_reaction_to_message(
self, id: str, user_id: str, name: str
) -> Optional[MessageReactionModel]:
with get_db() as db:
# check for existing reaction
existing_reaction = (
db.query(MessageReaction)
.filter_by(message_id=id, user_id=user_id, name=name)
.first()
)
if existing_reaction:
return MessageReactionModel.model_validate(existing_reaction)
reaction_id = str(uuid.uuid4())
reaction = MessageReactionModel(
id=reaction_id,
@ -411,30 +228,17 @@ class MessageTable:
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
with get_db() as db:
# JOIN User so all user info is fetched in one query
results = (
db.query(MessageReaction, User)
.join(User, MessageReaction.user_id == User.id)
.filter(MessageReaction.message_id == id)
.all()
)
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
reactions = {}
for reaction, user in results:
for reaction in all_reactions:
if reaction.name not in reactions:
reactions[reaction.name] = {
"name": reaction.name,
"users": [],
"user_ids": [],
"count": 0,
}
reactions[reaction.name]["users"].append(
{
"id": user.id,
"name": user.name,
}
)
reactions[reaction.name]["user_ids"].append(reaction.user_id)
reactions[reaction.name]["count"] += 1
return [Reactions(**reaction) for reaction in reactions.values()]

View file

@ -3,17 +3,15 @@ import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, cast, or_, and_, func
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
@ -21,6 +19,7 @@ from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
@ -53,7 +52,7 @@ class ModelMeta(BaseModel):
class Model(Base):
__tablename__ = "model"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
@ -133,11 +132,6 @@ class ModelResponse(ModelModel):
pass
class ModelListResponse(BaseModel):
items: list[ModelUserResponse]
total: int
class ModelForm(BaseModel):
id: str
base_model_id: Optional[str] = None
@ -181,16 +175,9 @@ class ModelsTable:
def get_models(self) -> list[ModelUserResponse]:
with get_db() as db:
all_models = db.query(Model).filter(Model.base_model_id != None).all()
user_ids = list(set(model.user_id for model in all_models))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
models = []
for model in all_models:
user = users_dict.get(model.user_id)
for model in db.query(Model).filter(Model.base_model_id != None).all():
user = Users.get_user_by_id(model.user_id)
models.append(
ModelUserResponse.model_validate(
{
@ -212,143 +199,13 @@ class ModelsTable:
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
models = self.get_models()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control, user_group_ids)
or has_access(user_id, permission, model.access_control)
]
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
dialect_name = db.bind.dialect.name
# Public access
conditions = []
if group_ids or user_id:
conditions.extend(
[
Model.access_control.is_(None),
cast(Model.access_control, String) == "null",
]
)
# User-level permission
if user_id:
conditions.append(Model.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
Model.access_control[permission]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Model.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query
def search_models(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> ModelListResponse:
with get_db() as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
query = query.filter(Model.base_model_id != None)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Model.name.ilike(f"%{query_key}%"),
Model.base_model_id.ilike(f"%{query_key}%"),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Model.user_id == user_id)
elif view_option == "shared":
query = query.filter(Model.user_id != user_id)
# Apply access control filtering
query = self._has_permission(
db,
query,
filter,
permission="write",
)
tag = filter.get("tag")
if tag:
# TODO: This is a simple implementation and should be improved for performance
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
meta_text = func.lower(cast(Model.meta, String))
query = query.filter(meta_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Model.name.asc())
else:
query = query.order_by(Model.name.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Model.created_at.asc())
else:
query = query.order_by(Model.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Model.updated_at.asc())
else:
query = query.order_by(Model.updated_at.desc())
else:
query = query.order_by(Model.created_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
models = []
for model, user in items:
models.append(
ModelUserResponse(
**ModelModel.model_validate(model).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
)
)
return ModelListResponse(items=models, total=total)
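Illustrative sketch (not part of either branch): the filter dict accepted by search_models above, reconstructed from the keys the method reads. Treat the exact keys as assumptions inferred from this diff, and `Models` as the module-level ModelsTable() singleton:

example_filter = {
    "query": "llama",          # ILIKE match on Model.name / Model.base_model_id
    "view_option": "created",  # "created" = own models, "shared" = everyone else's
    "tag": "vision",           # matched inside the JSON meta blob via LIKE
    "order_by": "updated_at",  # "name" | "created_at" | "updated_at"
    "direction": "desc",
    "user_id": "user-1",       # consumed by _has_permission for owner-level access
    "group_ids": ["group-a"],  # consumed by _has_permission for group-level access
}

# page = Models.search_models("user-1", filter=example_filter, skip=0, limit=30)
# page.items -> list[ModelUserResponse]; page.total -> count before pagination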
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:
@ -357,14 +214,6 @@ class ModelsTable:
except Exception:
return None
def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]:
try:
with get_db() as db:
models = db.query(Model).filter(Model.id.in_(ids)).all()
return [ModelModel.model_validate(model) for model in models]
except Exception:
return []
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
with get_db() as db:
try:
@ -386,9 +235,11 @@ class ModelsTable:
try:
with get_db() as db:
# update only the fields that are present in the model
data = model.model_dump(exclude={"id"})
result = db.query(Model).filter_by(id=id).update(data)
result = (
db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}))
)
db.commit()
model = db.get(Model, id)
@ -418,49 +269,5 @@ class ModelsTable:
except Exception:
return False
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
try:
with get_db() as db:
# Get existing models
existing_models = db.query(Model).all()
existing_ids = {model.id for model in existing_models}
# Prepare a set of new model IDs
new_model_ids = {model.id for model in models}
# Update or insert models
for model in models:
if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update(
{
**model.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
else:
new_model = Model(
**{
**model.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
db.add(new_model)
# Remove models that are no longer present
for model in existing_models:
if model.id not in new_model_ids:
db.delete(model)
db.commit()
return [
ModelModel.model_validate(model) for model in db.query(Model).all()
]
except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}")
return []
Models = ModelsTable()

View file

@ -2,20 +2,15 @@ import json
import time
import uuid
from typing import Optional
from functools import lru_cache
from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access
from open_webui.models.users import User, UserModel, Users, UserResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
@ -26,7 +21,7 @@ from sqlalchemy.sql import exists
class Note(Base):
__tablename__ = "note"
id = Column(Text, primary_key=True, unique=True)
id = Column(Text, primary_key=True)
user_id = Column(Text)
title = Column(Text)
@ -67,149 +62,11 @@ class NoteForm(BaseModel):
access_control: Optional[dict] = None
class NoteUpdateForm(BaseModel):
title: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None
class NoteItemResponse(BaseModel):
id: str
title: str
data: Optional[dict]
updated_at: int
created_at: int
user: Optional[UserResponse] = None
class NoteListResponse(BaseModel):
items: list[NoteUserResponse]
total: int
class NoteTable:
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
dialect_name = db.bind.dialect.name
conditions = []
# Handle read_only permission separately
if permission == "read_only":
# For read_only, we want items where:
# 1. User has explicit read permission (via groups or user-level)
# 2. BUT does NOT have write permission
# 3. Public items are NOT considered read_only
read_conditions = []
# Group-level read permission
if group_ids:
group_read_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_read_conditions.append(
Note.access_control["read"]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_read_conditions.append(
cast(
Note.access_control["read"]["group_ids"],
JSONB,
).contains([gid])
)
if group_read_conditions:
read_conditions.append(or_(*group_read_conditions))
# Combine read conditions
if read_conditions:
has_read = or_(*read_conditions)
else:
# If no read conditions, return empty result
return query.filter(False)
# Now exclude items where user has write permission
write_exclusions = []
# Exclude items owned by user (they have implicit write)
if user_id:
write_exclusions.append(Note.user_id != user_id)
# Exclude items where user has explicit write permission via groups
if group_ids:
group_write_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_write_conditions.append(
Note.access_control["write"]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_write_conditions.append(
cast(
Note.access_control["write"]["group_ids"],
JSONB,
).contains([gid])
)
if group_write_conditions:
# User should NOT have write permission
write_exclusions.append(~or_(*group_write_conditions))
# Exclude public items (items without access_control)
write_exclusions.append(Note.access_control.isnot(None))
write_exclusions.append(cast(Note.access_control, String) != "null")
# Combine: has read AND does not have write AND not public
if write_exclusions:
query = query.filter(and_(has_read, *write_exclusions))
else:
query = query.filter(has_read)
return query
# Original logic for other permissions (read, write, etc.)
# Public access conditions
if group_ids or user_id:
conditions.extend(
[
Note.access_control.is_(None),
cast(Note.access_control, String) == "null",
]
)
# User-level permission (owner has all permissions)
if user_id:
conditions.append(Note.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
Note.access_control[permission]["group_ids"].contains([gid])
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Note.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query
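Illustrative sketch (not part of either branch): the read_only branch above encodes "readable via a group, not writable, not the owner, and not public". The same rule in plain Python, under the access_control shape assumed earlier in this diff:

def _is_read_only_for(note_owner_id, access_control, user_id, user_group_ids: set[str]) -> bool:
    if access_control is None:
        return False  # public notes are not treated as read-only
    if note_owner_id == user_id:
        return False  # owners have implicit write access
    read_groups = set(access_control.get("read", {}).get("group_ids", []))
    write_groups = set(access_control.get("write", {}).get("group_ids", []))
    has_read = bool(user_group_ids & read_groups)
    has_write = bool(user_group_ids & write_groups)
    return has_read and not has_write

acl = {"read": {"group_ids": ["viewers"]}, "write": {"group_ids": ["editors"]}}
assert _is_read_only_for("owner", acl, "user-1", {"viewers"}) is True
assert _is_read_only_for("owner", acl, "user-1", {"viewers", "editors"}) is False
assert _is_read_only_for("owner", None, "user-1", {"viewers"}) is False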
def insert_new_note(
self,
form_data: NoteForm,
@ -232,153 +89,37 @@ class NoteTable:
db.commit()
return note
def get_notes(
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[NoteModel]:
def get_notes(self) -> list[NoteModel]:
with get_db() as db:
query = db.query(Note).order_by(Note.updated_at.desc())
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
notes = query.all()
notes = db.query(Note).order_by(Note.updated_at.desc()).all()
return [NoteModel.model_validate(note) for note in notes]
def search_notes(
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> NoteListResponse:
with get_db() as db:
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Note.title.ilike(f"%{query_key}%"),
cast(Note.data["content"]["md"], Text).ilike(
f"%{query_key}%"
),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Note.user_id == user_id)
elif view_option == "shared":
query = query.filter(Note.user_id != user_id)
# Apply access control filtering
if "permission" in filter:
permission = filter["permission"]
else:
permission = "write"
query = self._has_permission(
db,
query,
filter,
permission=permission,
)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Note.title.asc())
else:
query = query.order_by(Note.title.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Note.created_at.asc())
else:
query = query.order_by(Note.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Note.updated_at.asc())
else:
query = query.order_by(Note.updated_at.desc())
else:
query = query.order_by(Note.updated_at.desc())
else:
query = query.order_by(Note.updated_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
notes = []
for note, user in items:
notes.append(
NoteUserResponse(
**NoteModel.model_validate(note).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
)
)
return NoteListResponse(items=notes, total=total)
def get_notes_by_user_id(
self,
user_id: str,
permission: str = "read",
skip: Optional[int] = None,
limit: Optional[int] = None,
self, user_id: str, permission: str = "write"
) -> list[NoteModel]:
with get_db() as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
]
query = db.query(Note).order_by(Note.updated_at.desc())
query = self._has_permission(
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
)
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
notes = query.all()
return [NoteModel.model_validate(note) for note in notes]
notes = self.get_notes()
return [
note
for note in notes
if note.user_id == user_id
or has_access(user_id, permission, note.access_control)
]
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None
def update_note_by_id(
self, id: str, form_data: NoteUpdateForm
) -> Optional[NoteModel]:
def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
form_data = form_data.model_dump(exclude_unset=True)
if "title" in form_data:
note.title = form_data["title"]
if "data" in form_data:
note.data = {**note.data, **form_data["data"]}
if "meta" in form_data:
note.meta = {**note.meta, **form_data["meta"]}
if "access_control" in form_data:
note.access_control = form_data["access_control"]
note.title = form_data.title
note.data = form_data.data
note.meta = form_data.meta
note.access_control = form_data.access_control
note.updated_at = int(time.time_ns())
db.commit()

View file

@ -1,276 +0,0 @@
import time
import logging
import uuid
from typing import Optional, List
import base64
import hashlib
import json
from cryptography.fernet import Fernet
from open_webui.internal.db import Base, get_db
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Index
log = logging.getLogger(__name__)
####################
# DB MODEL
####################
class OAuthSession(Base):
__tablename__ = "oauth_session"
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False)
token = Column(
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance
__table_args__ = (
Index("idx_oauth_session_user_id", "user_id"),
Index("idx_oauth_session_expires_at", "expires_at"),
Index("idx_oauth_session_user_provider", "user_id", "provider"),
)
class OAuthSessionModel(BaseModel):
id: str
user_id: str
provider: str
token: dict
expires_at: int # timestamp in epoch
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class OAuthSessionResponse(BaseModel):
id: str
user_id: str
provider: str
expires_at: int
class OAuthSessionTable:
def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44:
key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
self.encryption_key = base64.urlsafe_b64encode(key_bytes)
else:
self.encryption_key = self.encryption_key.encode()
try:
self.fernet = Fernet(self.encryption_key)
except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}")
raise
def _encrypt_token(self, token) -> str:
"""Encrypt OAuth tokens for storage"""
try:
token_json = json.dumps(token)
encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
raise
def _decrypt_token(self, token: str):
"""Decrypt OAuth tokens from storage"""
try:
decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {e}")
raise
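Illustrative sketch (not part of either branch): the key normalization and token round-trip performed by __init__ and the two helpers above, as a standalone example. The secret below is a placeholder, not a real key:

import base64
import hashlib
import json

from cryptography.fernet import Fernet

def _derive_fernet_key(secret: str) -> bytes:
    # Mirror __init__: a 44-character value is assumed to already be a urlsafe-base64
    # Fernet key; anything else is hashed to 32 bytes and re-encoded.
    if len(secret) == 44:
        return secret.encode()
    return base64.urlsafe_b64encode(hashlib.sha256(secret.encode()).digest())

fernet = Fernet(_derive_fernet_key("not-a-real-key"))  # placeholder secret
token = {"access_token": "abc", "refresh_token": "def", "expires_at": 1700000000}

ciphertext = fernet.encrypt(json.dumps(token).encode()).decode()
assert json.loads(fernet.decrypt(ciphertext.encode()).decode()) == token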
def create_session(
self,
user_id: str,
provider: str,
token: dict,
) -> Optional[OAuthSessionModel]:
"""Create a new OAuth session"""
try:
with get_db() as db:
current_time = int(time.time())
id = str(uuid.uuid4())
result = OAuthSession(
**{
"id": id,
"user_id": user_id,
"provider": provider,
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"created_at": current_time,
"updated_at": current_time,
}
)
db.add(result)
db.commit()
db.refresh(result)
if result:
result.token = token # Return decrypted token
return OAuthSessionModel.model_validate(result)
else:
return None
except Exception as e:
log.error(f"Error creating OAuth session: {e}")
return None
def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID"""
try:
with get_db() as db:
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_id_and_user_id(
self, session_id: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
return None
def get_session_by_provider_and_user_id(
self, provider: str, user_id: str
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by provider and user ID"""
try:
with get_db() as db:
session = (
db.query(OAuthSession)
.filter_by(provider=provider, user_id=user_id)
.first()
)
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
return None
def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
with get_db() as db:
sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
results = []
for session in sessions:
session.token = self._decrypt_token(session.token)
results.append(OAuthSessionModel.model_validate(session))
return results
except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}")
return []
def update_session_by_id(
self, session_id: str, token: dict
) -> Optional[OAuthSessionModel]:
"""Update OAuth session tokens"""
try:
with get_db() as db:
current_time = int(time.time())
db.query(OAuthSession).filter_by(id=session_id).update(
{
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"updated_at": current_time,
}
)
db.commit()
session = db.query(OAuthSession).filter_by(id=session_id).first()
if session:
session.token = self._decrypt_token(session.token)
return OAuthSessionModel.model_validate(session)
return None
except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}")
return None
def delete_session_by_id(self, session_id: str) -> bool:
"""Delete an OAuth session"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(id=session_id).delete()
db.commit()
return result > 0
except Exception as e:
log.error(f"Error deleting OAuth session: {e}")
return False
def delete_sessions_by_user_id(self, user_id: str) -> bool:
"""Delete all OAuth sessions for a user"""
try:
with get_db() as db:
result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False
def delete_sessions_by_provider(self, provider: str) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db() as db:
db.query(OAuthSession).filter_by(provider=provider).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
return False
OAuthSessions = OAuthSessionTable()
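A minimal usage sketch of the session table above; the user id, provider, and token values are placeholders, and the token dict shape simply mirrors what create_session() reads (an "expires_at" key plus whatever the provider returned):

token = {
    "access_token": "ya29.placeholder",
    "refresh_token": "1//placeholder",
    "expires_at": int(time.time()) + 3600,
}

session = OAuthSessions.create_session(user_id="user-123", provider="google", token=token)
if session:
    # Stored encrypted with Fernet; every read path returns the decrypted dict.
    restored = OAuthSessions.get_session_by_id_and_user_id(session.id, "user-123")
    OAuthSessions.update_session_by_id(session.id, {**token, "access_token": "refreshed"})
    OAuthSessions.delete_session_by_id(session.id)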

View file

@ -2,7 +2,6 @@ import time
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
@ -104,16 +103,10 @@ class PromptsTable:
def get_prompts(self) -> list[PromptUserResponse]:
with get_db() as db:
all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
user_ids = list(set(prompt.user_id for prompt in all_prompts))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
prompts = []
for prompt in all_prompts:
user = users_dict.get(prompt.user_id)
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
user = Users.get_user_by_id(prompt.user_id)
prompts.append(
PromptUserResponse.model_validate(
{
@ -129,13 +122,12 @@ class PromptsTable:
self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]:
prompts = self.get_prompts()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control, user_group_ids)
or has_access(user_id, permission, prompt.access_control)
]
def update_prompt_by_command(

View file

@ -6,10 +6,12 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
@ -22,11 +24,6 @@ class Tag(Base):
user_id = Column(String)
meta = Column(JSON, nullable=True)
__table_args__ = (
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
Index("user_id_idx", "user_id"),
)
# Composite primary key ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)

View file

@ -4,8 +4,7 @@ from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse
from open_webui.models.groups import Groups
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@ -13,6 +12,7 @@ from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
@ -22,7 +22,7 @@ log = logging.getLogger(__name__)
class Tool(Base):
__tablename__ = "tool"
id = Column(String, primary_key=True, unique=True)
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
content = Column(Text)
@ -93,8 +93,6 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
model_config = ConfigDict(extra="allow")
class ToolForm(BaseModel):
id: str
@ -146,16 +144,9 @@ class ToolsTable:
def get_tools(self) -> list[ToolUserModel]:
with get_db() as db:
all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all()
user_ids = list(set(tool.user_id for tool in all_tools))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
users_dict = {user.id: user for user in users}
tools = []
for tool in all_tools:
user = users_dict.get(tool.user_id)
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = Users.get_user_by_id(tool.user_id)
tools.append(
ToolUserModel.model_validate(
{
@ -170,13 +161,12 @@ class ToolsTable:
self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]:
tools = self.get_tools()
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
return [
tool
for tool in tools
if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control, user_group_ids)
or has_access(user_id, permission, tool.access_control)
]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
@ -185,7 +175,7 @@ class ToolsTable:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
log.exception(f"Error getting tool valves by id {id}")
log.exception(f"Error getting tool valves by id {id}: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:

View file

@ -4,139 +4,62 @@ from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups, GroupMember
from open_webui.models.channels import ChannelMember
from open_webui.utils.misc import throttle
from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
BigInteger,
JSON,
Column,
String,
Boolean,
Text,
Date,
exists,
select,
cast,
)
from sqlalchemy import or_, case
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import or_
import datetime
####################
# User DB Schema
####################
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True)
name = Column(String)
email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
oauth_sub = Column(Text, unique=True)
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
pass
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True, unique=True)
email = Column(String)
username = Column(String(50), nullable=True)
role = Column(String)
name = Column(String)
profile_image_url = Column(Text)
profile_banner_image_url = Column(Text, nullable=True)
bio = Column(Text, nullable=True)
gender = Column(Text, nullable=True)
date_of_birth = Column(Date, nullable=True)
timezone = Column(String, nullable=True)
presence_state = Column(String, nullable=True)
status_emoji = Column(String, nullable=True)
status_message = Column(Text, nullable=True)
status_expires_at = Column(BigInteger, nullable=True)
info = Column(JSON, nullable=True)
settings = Column(JSON, nullable=True)
oauth = Column(JSON, nullable=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class UserModel(BaseModel):
id: str
email: str
username: Optional[str] = None
role: str = "pending"
name: str
email: str
role: str = "pending"
profile_image_url: str
profile_banner_image_url: Optional[str] = None
bio: Optional[str] = None
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
timezone: Optional[str] = None
presence_state: Optional[str] = None
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
info: Optional[dict] = None
settings: Optional[UserSettings] = None
oauth: Optional[dict] = None
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
info: Optional[dict] = None
class UserStatusModel(UserModel):
is_active: bool = False
model_config = ConfigDict(from_attributes=True)
class ApiKey(Base):
__tablename__ = "api_key"
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
key = Column(Text, unique=True, nullable=False)
data = Column(JSON, nullable=True)
expires_at = Column(BigInteger, nullable=True)
last_used_at = Column(BigInteger, nullable=True)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
class ApiKeyModel(BaseModel):
id: str
user_id: str
key: str
data: Optional[dict] = None
expires_at: Optional[int] = None
last_used_at: Optional[int] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
@ -146,78 +69,23 @@ class ApiKeyModel(BaseModel):
####################
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
bio: Optional[str] = None
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
class UserGroupIdsModel(UserModel):
group_ids: list[str] = []
class UserModelResponse(UserModel):
model_config = ConfigDict(extra="allow")
class UserListResponse(BaseModel):
users: list[UserModelResponse]
users: list[UserModel]
total: int
class UserGroupIdsListResponse(BaseModel):
users: list[UserGroupIdsModel]
total: int
class UserStatus(BaseModel):
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
class UserInfoResponse(UserStatus):
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str
class UserIdNameResponse(BaseModel):
id: str
name: str
class UserIdNameStatusResponse(UserStatus):
id: str
name: str
is_active: Optional[bool] = None
class UserInfoListResponse(BaseModel):
users: list[UserInfoResponse]
total: int
class UserIdNameListResponse(BaseModel):
users: list[UserIdNameResponse]
total: int
profile_image_url: str
class UserNameResponse(BaseModel):
id: str
name: str
role: str
class UserResponse(UserNameResponse):
email: str
class UserProfileImageResponse(UserNameResponse):
email: str
profile_image_url: str
@ -227,7 +95,6 @@ class UserRoleUpdateForm(BaseModel):
class UserUpdateForm(BaseModel):
role: str
name: str
email: str
profile_image_url: str
@ -242,20 +109,20 @@ class UsersTable:
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth: Optional[dict] = None,
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
user = UserModel(
**{
"id": id,
"email": email,
"name": name,
"email": email,
"role": role,
"profile_image_url": profile_image_url,
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth": oauth,
"oauth_sub": oauth_sub,
}
)
result = User(**user.model_dump())
@ -278,13 +145,8 @@ class UsersTable:
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = (
db.query(User)
.join(ApiKey, User.id == ApiKey.user_id)
.filter(ApiKey.key == api_key)
.first()
)
return UserModel.model_validate(user) if user else None
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except Exception:
return None
@ -296,23 +158,12 @@ class UsersTable:
except Exception:
return None
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
with get_db() as db: # type: Session
dialect_name = db.bind.dialect.name
query = db.query(User)
if dialect_name == "sqlite":
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
elif dialect_name == "postgresql":
query = query.filter(
User.oauth[provider].cast(JSONB)["sub"].astext == sub
)
user = query.first()
return UserModel.model_validate(user) if user else None
except Exception as e:
# You may want to log the exception here
with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except Exception:
return None
def get_users(
@ -320,9 +171,8 @@ class UsersTable:
filter: Optional[dict] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> dict:
) -> UserListResponse:
with get_db() as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(User)
if filter:
@ -335,76 +185,14 @@ class UsersTable:
)
)
channel_id = filter.get("channel_id")
if channel_id:
query = query.filter(
exists(
select(ChannelMember.id).where(
ChannelMember.user_id == User.id,
ChannelMember.channel_id == channel_id,
)
)
)
user_ids = filter.get("user_ids")
group_ids = filter.get("group_ids")
if isinstance(user_ids, list) and isinstance(group_ids, list):
# If both are empty lists, return no users
if not user_ids and not group_ids:
return {"users": [], "total": 0}
if user_ids:
query = query.filter(User.id.in_(user_ids))
if group_ids:
query = query.filter(
exists(
select(GroupMember.id).where(
GroupMember.user_id == User.id,
GroupMember.group_id.in_(group_ids),
)
)
)
roles = filter.get("roles")
if roles:
include_roles = [role for role in roles if not role.startswith("!")]
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
if include_roles:
query = query.filter(User.role.in_(include_roles))
if exclude_roles:
query = query.filter(~User.role.in_(exclude_roles))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and order_by.startswith("group_id:"):
group_id = order_by.split(":", 1)[1]
# Subquery that checks if the user belongs to the group
membership_exists = exists(
select(GroupMember.id).where(
GroupMember.user_id == User.id,
GroupMember.group_id == group_id,
)
)
# CASE: user in group → 1, user not in group → 0
group_sort = case((membership_exists, 1), else_=0)
if direction == "asc":
query = query.order_by(group_sort.asc(), User.name.asc())
else:
query = query.order_by(group_sort.desc(), User.name.asc())
elif order_by == "name":
if order_by == "name":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "email":
if direction == "asc":
query = query.order_by(User.email.asc())
@ -437,32 +225,18 @@ class UsersTable:
else:
query = query.order_by(User.created_at.desc())
# Count BEFORE pagination
total = query.count()
# Pagination: compare against None explicitly so a value of 0 is not treated as "unset"
if skip is not None:
if skip:
query = query.offset(skip)
if limit is not None:
if limit:
query = query.limit(limit)
users = query.all()
return {
"users": [UserModel.model_validate(user) for user in users],
"total": total,
"total": db.query(User).count(),
}
def get_users_by_group_id(self, group_id: str) -> list[UserModel]:
with get_db() as db:
users = (
db.query(User)
.join(GroupMember, User.id == GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.all()
)
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]:
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users]
@ -471,10 +245,6 @@ class UsersTable:
with get_db() as db:
return db.query(User).count()
def has_users(self) -> bool:
with get_db() as db:
return db.query(db.query(User).exists()).scalar()
def get_first_user(self) -> UserModel:
try:
with get_db() as db:
@ -499,15 +269,6 @@ class UsersTable:
except Exception:
return None
def get_num_users_active_today(self) -> Optional[int]:
with get_db() as db:
current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
query = db.query(User).filter(
User.last_active_at > today_midnight_timestamp
)
return query.count()
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try:
with get_db() as db:
@ -518,21 +279,6 @@ class UsersTable:
except Exception:
return None
def update_user_status_by_id(
self, id: str, form_data: UserStatus
) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
{**form_data.model_dump(exclude_none=True)}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str
) -> Optional[UserModel]:
@ -548,8 +294,7 @@ class UsersTable:
except Exception:
return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
@ -562,35 +307,16 @@ class UsersTable:
except Exception:
return None
def update_user_oauth_by_id(
self, id: str, provider: str, sub: str
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
"""
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
Example resulting structure:
{
"google": { "sub": "123" },
"github": { "sub": "abc" }
}
"""
try:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
if not user:
return None
# Load existing oauth JSON or create empty
oauth = user.oauth or {}
# Update or insert provider entry
oauth[provider] = {"sub": sub}
# Persist updated JSON
db.query(User).filter_by(id=id).update({"oauth": oauth})
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
@ -603,8 +329,7 @@ class UsersTable:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
print(e)
except Exception:
return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
@ -644,45 +369,23 @@ class UsersTable:
except Exception:
return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except Exception:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
with get_db() as db:
api_key = db.query(ApiKey).filter_by(user_id=id).first()
return api_key.key if api_key else None
user = db.query(User).filter_by(id=id).first()
return user.api_key
except Exception:
return None
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try:
with get_db() as db:
db.query(ApiKey).filter_by(user_id=id).delete()
db.commit()
now = int(time.time())
new_api_key = ApiKey(
id=f"key_{id}",
user_id=id,
key=api_key,
created_at=now,
updated_at=now,
)
db.add(new_api_key)
db.commit()
return True
except Exception:
return False
def delete_user_api_key_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(ApiKey).filter_by(user_id=id).delete()
db.commit()
return True
except Exception:
return False
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
@ -696,23 +399,5 @@ class UsersTable:
else:
return None
def get_active_user_count(self) -> int:
with get_db() as db:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
count = (
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
)
return count
def is_user_active(self, user_id: str) -> bool:
with get_db() as db:
user = db.query(User).filter_by(id=user_id).first()
if user and user.last_active_at:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
return user.last_active_at >= three_minutes_ago
return False
Users = UsersTable()
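A hedged sketch of the filter dict handled by get_users() above, assuming the variant that returns a plain dict; the role names and group id are illustrative:

# "!pending" excludes a role, "group_id:<id>" sorts that group's members first,
# and skip/limit paginate after the total has been counted.
page = Users.get_users(
    filter={
        "roles": ["user", "admin", "!pending"],
        "order_by": "group_id:engineering",  # hypothetical group id
        "direction": "desc",
    },
    skip=0,
    limit=50,
)
users, total = page["users"], page["total"]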

View file

@ -15,28 +15,24 @@ class DatalabMarkerLoader:
self,
file_path: str,
api_key: str,
api_base_url: str,
additional_config: Optional[str] = None,
langs: Optional[str] = None,
use_llm: bool = False,
skip_cache: bool = False,
force_ocr: bool = False,
paginate: bool = False,
strip_existing_ocr: bool = False,
disable_image_extraction: bool = False,
format_lines: bool = False,
output_format: str = None,
):
self.file_path = file_path
self.api_key = api_key
self.api_base_url = api_base_url
self.additional_config = additional_config
self.langs = langs
self.use_llm = use_llm
self.skip_cache = skip_cache
self.force_ocr = force_ocr
self.paginate = paginate
self.strip_existing_ocr = strip_existing_ocr
self.disable_image_extraction = disable_image_extraction
self.format_lines = format_lines
self.output_format = output_format
def _get_mime_type(self, filename: str) -> str:
@ -64,7 +60,7 @@ class DatalabMarkerLoader:
return mime_map.get(ext, "application/octet-stream")
def check_marker_request_status(self, request_id: str) -> dict:
url = f"{self.api_base_url}/{request_id}"
url = f"https://www.datalab.to/api/v1/marker/{request_id}"
headers = {"X-Api-Key": self.api_key}
try:
response = requests.get(url, headers=headers)
@ -85,24 +81,22 @@ class DatalabMarkerLoader:
)
def load(self) -> List[Document]:
url = "https://www.datalab.to/api/v1/marker"
filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key}
form_data = {
"langs": self.langs,
"use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(),
"format_lines": str(self.format_lines).lower(),
"output_format": self.output_format,
}
if self.additional_config and self.additional_config.strip():
form_data["additional_config"] = self.additional_config
log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
)
@ -111,10 +105,7 @@ class DatalabMarkerLoader:
with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)}
response = requests.post(
f"{self.api_base_url}",
data=form_data,
files=files,
headers=headers,
url, data=form_data, files=files, headers=headers
)
response.raise_for_status()
result = response.json()
@ -142,92 +133,74 @@ class DatalabMarkerLoader:
check_url = result.get("request_check_url")
request_id = result.get("request_id")
if not check_url:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
)
# Check if this is a direct response (self-hosted) or polling response (DataLab)
if check_url:
# DataLab polling pattern
for _ in range(300): # Up to 10 minutes
time.sleep(2)
try:
poll_response = requests.get(check_url, headers=headers)
poll_response.raise_for_status()
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
status_val = poll_result.get("status")
success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
)
else:
for _ in range(300): # Up to 10 minutes
time.sleep(2)
try:
poll_response = requests.get(check_url, headers=headers)
poll_response.raise_for_status()
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="Marker processing timed out",
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
status_val = poll_result.get("status")
success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
detail=f"Marker processing failed: {error_msg}",
)
# DataLab format - content in format-specific fields
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
final_result = poll_result
else:
# Self-hosted direct response - content in "output" field
if "output" in result:
log.info("Self-hosted Marker returned direct response without polling")
raw_content = result.get("output")
final_result = result
else:
available_fields = (
list(result.keys())
if isinstance(result, dict)
else "non-dict response"
)
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
)
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
)
if self.output_format.lower() == "json":
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
)
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
if content_key == "json":
full_text = json.dumps(raw_content, indent=2)
elif self.output_format.lower() in {"markdown", "html"}:
elif content_key in {"markdown", "html"}:
full_text = str(raw_content).strip()
else:
raise HTTPException(
@ -238,14 +211,14 @@ class DatalabMarkerLoader:
if not full_text:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Marker returned empty content",
detail="Datalab Marker returned empty content",
)
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
file_ext = file_ext_map.get(content_key, "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
output_path = os.path.join(marker_output_dir, output_filename)
@ -258,13 +231,13 @@ class DatalabMarkerLoader:
metadata = {
"source": filename,
"output_format": final_result.get("output_format", self.output_format),
"page_count": final_result.get("page_count", 0),
"output_format": poll_result.get("output_format", self.output_format),
"page_count": poll_result.get("page_count", 0),
"processed_with_llm": self.use_llm,
"request_id": request_id or "",
}
images = final_result.get("images", {})
images = poll_result.get("images", {})
if images:
metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys()))

View file

@ -1,13 +1,13 @@
import requests
import logging, os
import logging
from typing import Iterator, List, Union
from urllib.parse import quote
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.utils.headers import include_user_info_headers
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalDocumentLoader(BaseLoader):
@ -17,7 +17,6 @@ class ExternalDocumentLoader(BaseLoader):
url: str,
api_key: str,
mime_type=None,
user=None,
**kwargs,
) -> None:
self.url = url
@ -26,9 +25,7 @@ class ExternalDocumentLoader(BaseLoader):
self.file_path = file_path
self.mime_type = mime_type
self.user = user
def load(self) -> List[Document]:
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
@ -39,51 +36,23 @@ class ExternalDocumentLoader(BaseLoader):
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
try:
headers["X-Filename"] = quote(os.path.basename(self.file_path))
except Exception:
pass
if self.user is not None:
headers = include_user_info_headers(headers, self.user)
url = self.url
if url.endswith("/"):
url = url[:-1]
try:
response = requests.put(f"{url}/process", data=data, headers=headers)
except Exception as e:
log.error(f"Error connecting to endpoint: {e}")
raise Exception(f"Error connecting to endpoint: {e}")
r = requests.put(f"{url}/process", data=data, headers=headers)
if response.ok:
response_data = response.json()
if response_data:
if isinstance(response_data, dict):
return [
Document(
page_content=response_data.get("page_content"),
metadata=response_data.get("metadata"),
)
]
elif isinstance(response_data, list):
documents = []
for document in response_data:
documents.append(
Document(
page_content=document.get("page_content"),
metadata=document.get("metadata"),
)
)
return documents
else:
raise Exception("Error loading document: Unable to parse content")
if r.ok:
res = r.json()
if res:
return [
Document(
page_content=res.get("page_content"),
metadata=res.get("metadata"),
)
]
else:
raise Exception("Error loading document: No content returned")
else:
raise Exception(
f"Error loading document: {response.status_code} {response.text}"
)
raise Exception(f"Error loading document: {r.status_code} {r.text}")

View file

@ -4,8 +4,10 @@ from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalWebLoader(BaseLoader):

View file

@ -2,9 +2,7 @@ import requests
import logging
import ftfy
import sys
import json
from azure.identity import DefaultAzureCredential
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
BSHTMLLoader,
@ -15,7 +13,7 @@ from langchain_community.document_loaders import (
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
UnstructuredODTLoader,
UnstructuredMarkdownLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
@ -27,13 +25,13 @@ from open_webui.retrieval.loaders.external_document import ExternalDocumentLoade
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.retrieval.loaders.mineru import MinerULoader
from open_webui.env import GLOBAL_LOG_LEVEL
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
known_source_ext = [
"go",
@ -78,6 +76,7 @@ known_source_ext = [
"swift",
"vue",
"svelte",
"msg",
"ex",
"exs",
"erl",
@ -131,9 +130,8 @@ class TikaLoader:
class DoclingLoader:
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
def __init__(self, url, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/")
self.api_key = api_key
self.file_path = file_path
self.mime_type = mime_type
@ -141,25 +139,36 @@ class DoclingLoader:
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
headers = {}
if self.api_key:
headers["X-Api-Key"] = f"Bearer {self.api_key}"
files = {
"files": (
self.file_path,
f,
self.mime_type or "application/octet-stream",
)
}
r = requests.post(
f"{self.url}/v1/convert/file",
files={
"files": (
self.file_path,
f,
self.mime_type or "application/octet-stream",
params = {
"image_export_mode": "placeholder",
"table_mode": "accurate",
}
if self.params:
if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
)
},
data={
"image_export_mode": "placeholder",
**self.params,
},
headers=headers,
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
lang.strip()
for lang in self.params.get("ocr_lang").split(",")
if lang.strip()
]
endpoint = f"{self.url}/v1alpha/convert/file"
r = requests.post(endpoint, files=files, data=params)
if r.ok:
result = r.json()
document_data = result.get("document", {})
@ -168,6 +177,7 @@ class DoclingLoader:
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text)
return [Document(page_content=text, metadata=metadata)]
else:
error_msg = f"Error calling Docling API: {r.reason}"
@ -184,7 +194,6 @@ class DoclingLoader:
class Loader:
def __init__(self, engine: str = "", **kwargs):
self.engine = engine
self.user = kwargs.get("user", None)
self.kwargs = kwargs
def load(
@ -202,10 +211,7 @@ class Loader:
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
file_content_type
and file_content_type.find("text/") >= 0
# Avoid text/html files being detected as text
and not file_content_type.find("html") >= 0
file_content_type and file_content_type.find("text/") >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
@ -221,7 +227,6 @@ class Loader:
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
mime_type=file_content_type,
user=self.user,
)
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
@ -230,6 +235,7 @@ class Loader:
loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
)
elif (
@ -257,15 +263,10 @@ class Loader:
"tiff",
]
):
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
if not api_base_url or api_base_url.strip() == "":
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
loader = DatalabMarkerLoader(
file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
api_base_url=api_base_url,
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
@ -276,7 +277,6 @@ class Loader:
disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
@ -285,67 +285,38 @@ class Loader:
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
# Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {})
if not isinstance(params, dict):
try:
params = json.loads(params)
except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
params = {}
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
api_key=self.kwargs.get("DOCLING_API_KEY", None),
file_path=file_path,
mime_type=file_content_type,
params=params,
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and (
file_ext in ["pdf", "docx", "ppt", "pptx"]
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type
in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]
)
):
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
)
else:
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
azure_credential=DefaultAzureCredential(),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
)
elif self.engine == "mineru" and file_ext in [
"pdf"
]: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
if mineru_timeout:
try:
mineru_timeout = int(mineru_timeout)
except ValueError:
mineru_timeout = 300
loader = MinerULoader(
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
api_key=self.kwargs.get("MINERU_API_KEY", ""),
params=self.kwargs.get("MINERU_PARAMS", {}),
timeout=mineru_timeout,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
)
elif (
self.engine == "mistral_ocr"
@ -354,9 +325,16 @@ class Loader:
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
file_path=file_path,
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
elif (
self.engine == "external"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
else:
if file_ext == "pdf":
@ -393,8 +371,6 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext == "odt":
loader = UnstructuredODTLoader(file_path)
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:

View file

@ -1,524 +0,0 @@
import os
import time
import requests
import logging
import tempfile
import zipfile
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status
log = logging.getLogger(__name__)
class MinerULoader:
"""
MinerU document parser loader supporting both Cloud API and Local API modes.
Cloud API: Uses MinerU managed service with async task-based processing
Local API: Uses self-hosted MinerU API with synchronous processing
"""
def __init__(
self,
file_path: str,
api_mode: str = "local",
api_url: str = "http://localhost:8000",
api_key: str = "",
params: dict = None,
timeout: Optional[int] = 300,
):
self.file_path = file_path
self.api_mode = api_mode.lower()
self.api_url = api_url.rstrip("/")
self.api_key = api_key
self.timeout = timeout
# Parse params dict with defaults
self.params = params or {}
self.enable_ocr = self.params.get("enable_ocr", False)
self.enable_formula = self.params.get("enable_formula", True)
self.enable_table = self.params.get("enable_table", True)
self.language = self.params.get("language", "en")
self.model_version = self.params.get("model_version", "pipeline")
self.page_ranges = self.params.pop("page_ranges", "")
# Validate API mode
if self.api_mode not in ["local", "cloud"]:
raise ValueError(
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
)
# Validate Cloud API requirements
if self.api_mode == "cloud" and not self.api_key:
raise ValueError("API key is required for Cloud API mode")
def load(self) -> List[Document]:
"""
Main entry point for loading and parsing the document.
Routes to Cloud or Local API based on api_mode.
"""
try:
if self.api_mode == "cloud":
return self._load_cloud_api()
else:
return self._load_local_api()
except Exception as e:
log.error(f"Error loading document with MinerU: {e}")
raise
def _load_local_api(self) -> List[Document]:
"""
Load document using Local API (synchronous).
Posts file to /file_parse endpoint and gets immediate response.
"""
log.info(f"Using MinerU Local API at {self.api_url}")
filename = os.path.basename(self.file_path)
# Build form data for Local API
form_data = {
**self.params,
"return_md": "true",
}
# Page ranges (Local API uses start_page_id and end_page_id)
if self.page_ranges:
# For simplicity, if page_ranges is specified, log a warning
# Full page range parsing would require parsing the string
log.warning(
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
"Consider using start_page_id/end_page_id parameters if needed."
)
try:
with open(self.file_path, "rb") as f:
files = {"files": (filename, f, "application/octet-stream")}
log.info(f"Sending file to MinerU Local API: {filename}")
log.debug(f"Local API parameters: {form_data}")
response = requests.post(
f"{self.api_url}/file_parse",
data=form_data,
files=files,
timeout=self.timeout,
)
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU Local API request timed out",
)
except requests.HTTPError as e:
error_detail = f"MinerU Local API request failed: {e}"
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data}"
except Exception:
error_detail += f" - {e.response.text}"
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error calling MinerU Local API: {str(e)}",
)
# Parse response
try:
result = response.json()
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response from MinerU Local API: {e}",
)
# Extract markdown content from response
if "results" not in result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Local API response missing 'results' field",
)
results = result["results"]
if not results:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty results",
)
# Get the first (and typically only) result
file_result = list(results.values())[0]
markdown_content = file_result.get("md_content", "")
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty markdown content",
)
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
# Create metadata
metadata = {
"source": filename,
"api_mode": "local",
"backend": result.get("backend", "unknown"),
"version": result.get("version", "unknown"),
}
return [Document(page_content=markdown_content, metadata=metadata)]
def _load_cloud_api(self) -> List[Document]:
"""
Load document using Cloud API (asynchronous).
Uses batch upload endpoint to avoid need for public file URLs.
"""
log.info(f"Using MinerU Cloud API at {self.api_url}")
filename = os.path.basename(self.file_path)
# Step 1: Request presigned upload URL
batch_id, upload_url = self._request_upload_url(filename)
# Step 2: Upload file to presigned URL
self._upload_to_presigned_url(upload_url)
# Step 3: Poll for results
result = self._poll_batch_status(batch_id, filename)
# Step 4: Download and extract markdown from ZIP
markdown_content = self._download_and_extract_zip(
result["full_zip_url"], filename
)
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
# Create metadata
metadata = {
"source": filename,
"api_mode": "cloud",
"batch_id": batch_id,
}
return [Document(page_content=markdown_content, metadata=metadata)]
def _request_upload_url(self, filename: str) -> tuple:
"""
Request presigned upload URL from Cloud API.
Returns (batch_id, upload_url).
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
# Build request body
request_body = {
**self.params,
"files": [
{
"name": filename,
"is_ocr": self.enable_ocr,
}
],
}
# Add page ranges if specified
if self.page_ranges:
request_body["files"][0]["page_ranges"] = self.page_ranges
log.info(f"Requesting upload URL for: {filename}")
log.debug(f"Cloud API request body: {request_body}")
try:
response = requests.post(
f"{self.api_url}/file-urls/batch",
headers=headers,
json=request_body,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to request upload URL: {e}"
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
except Exception:
error_detail += f" - {e.response.text}"
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error requesting upload URL: {str(e)}",
)
try:
result = response.json()
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response: {e}",
)
# Check for API error response
if result.get("code") != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
)
data = result.get("data", {})
batch_id = data.get("batch_id")
file_urls = data.get("file_urls", [])
if not batch_id or not file_urls:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Cloud API response missing batch_id or file_urls",
)
upload_url = file_urls[0]
log.info(f"Received upload URL for batch: {batch_id}")
return batch_id, upload_url
def _upload_to_presigned_url(self, upload_url: str) -> None:
"""
Upload file to presigned URL (no authentication needed).
"""
log.info(f"Uploading file to presigned URL")
try:
with open(self.file_path, "rb") as f:
response = requests.put(
upload_url,
data=f,
timeout=self.timeout,
)
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="File upload to presigned URL timed out",
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to upload file to presigned URL: {e}",
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error uploading file: {str(e)}",
)
log.info("File uploaded successfully")
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
"""
Poll batch status until completion.
Returns the result dict for the file.
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
}
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
poll_interval = 2 # seconds
log.info(f"Polling batch status: {batch_id}")
for iteration in range(max_iterations):
try:
response = requests.get(
f"{self.api_url}/extract-results/batch/{batch_id}",
headers=headers,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to poll batch status: {e}"
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
except Exception:
error_detail += f" - {e.response.text}"
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error polling batch status: {str(e)}",
)
try:
result = response.json()
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response while polling: {e}",
)
# Check for API error response
if result.get("code") != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
)
data = result.get("data", {})
extract_result = data.get("extract_result", [])
# Find our file in the batch results
file_result = None
for item in extract_result:
if item.get("file_name") == filename:
file_result = item
break
if not file_result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"File {filename} not found in batch results",
)
state = file_result.get("state")
if state == "done":
log.info(f"Processing complete for {filename}")
return file_result
elif state == "failed":
error_msg = file_result.get("err_msg", "Unknown error")
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU processing failed: {error_msg}",
)
elif state in ["waiting-file", "pending", "running", "converting"]:
# Still processing
if iteration % 10 == 0: # Log every 20 seconds
log.info(
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
)
time.sleep(poll_interval)
else:
log.warning(f"Unknown state: {state}")
time.sleep(poll_interval)
# Timeout
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU processing timed out after 10 minutes",
)
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
"""
Download ZIP file from CDN and extract markdown content.
Returns the markdown content as a string.
"""
log.info(f"Downloading results from: {zip_url}")
try:
response = requests.get(zip_url, timeout=60)
response.raise_for_status()
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to download results ZIP: {e}",
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading results: {str(e)}",
)
# Save ZIP to temporary file and extract
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
tmp_zip.write(response.content)
tmp_zip_path = tmp_zip.name
with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
zip_ref.extractall(tmp_dir)
# Find markdown file - search recursively for any .md file
markdown_content = None
found_md_path = None
# First, list all files in the ZIP for debugging
all_files = []
for root, dirs, files in os.walk(tmp_dir):
for file in files:
full_path = os.path.join(root, file)
all_files.append(full_path)
# Look for any .md file
if file.endswith(".md"):
found_md_path = full_path
log.info(f"Found markdown file at: {full_path}")
try:
with open(full_path, "r", encoding="utf-8") as f:
markdown_content = f.read()
if (
markdown_content
): # Use the first non-empty markdown file
break
except Exception as e:
log.warning(f"Failed to read {full_path}: {e}")
if markdown_content:
break
if markdown_content is None:
log.error(f"Available files in ZIP: {all_files}")
# Try to provide more helpful error message
md_files = [f for f in all_files if f.endswith(".md")]
if md_files:
error_msg = (
f"Found .md files but couldn't read them: {md_files}"
)
else:
error_msg = (
f"No .md files found in ZIP. Available files: {all_files}"
)
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=error_msg,
)
# Clean up temporary ZIP file
os.unlink(tmp_zip_path)
except zipfile.BadZipFile as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid ZIP file received: {e}",
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting ZIP: {str(e)}",
)
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Extracted markdown content is empty",
)
log.info(
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
)
return markdown_content
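A hedged local-mode invocation sketch for the loader above; the URL and params mirror the defaults read in __init__, and the file path is a placeholder:

# Local mode posts to /file_parse and returns synchronously; cloud mode would
# additionally require api_key and goes through the batch upload/poll flow.
loader = MinerULoader(
    file_path="/app/backend/data/uploads/paper.pdf",
    api_mode="local",
    api_url="http://localhost:8000",
    params={"enable_ocr": False, "enable_formula": True, "language": "en"},
    timeout=300,
)
docs = loader.load()  # one Document whose page_content is the returned markdown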

View file

@ -9,29 +9,23 @@ from typing import List, Dict, Any
from contextlib import asynccontextmanager
from langchain_core.documents import Document
from open_webui.env import GLOBAL_LOG_LEVEL
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
Performance Optimizations:
- Differentiated timeouts for different operations
- Intelligent retry logic with exponential backoff
- Memory-efficient file streaming for large files
- Connection pooling and keepalive optimization
- Semaphore-based concurrency control for batch processing
- Enhanced error handling with retryable error classification
"""
BASE_API_URL = "https://api.mistral.ai/v1"
def __init__(
self,
base_url: str,
api_key: str,
file_path: str,
timeout: int = 300, # 5 minutes default
@ -53,49 +47,23 @@ class MistralLoader:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}")
self.base_url = (
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
)
self.api_key = api_key
self.file_path = file_path
self.timeout = timeout
self.max_retries = max_retries
self.debug = enable_debug_logging
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
# This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min(
timeout, 120
) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = (
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
# Pre-compute file info for performance
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
"User-Agent": "OpenWebUI-MistralLoader/2.0",
}
def _debug_log(self, message: str, *args) -> None:
"""
PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
Only processes debug messages when debug mode is enabled, avoiding
string formatting overhead in production environments.
"""
"""Conditional debug logging for performance."""
if self.debug:
log.debug(message, *args)
@ -147,118 +115,53 @@ class MistralLoader:
log.error(f"Unexpected error processing response: {e}")
raise
def _is_retryable_error(self, error: Exception) -> bool:
"""
ENHANCEMENT: Intelligent error classification for retry logic.
Determines if an error is retryable based on its type and status code.
This prevents wasting time retrying errors that will never succeed
(like authentication errors) while ensuring transient errors are retried.
Retryable errors:
- Network connection errors (temporary network issues)
- Timeouts (server might be temporarily overloaded)
- Server errors (5xx status codes - server-side issues)
- Rate limiting (429 status - temporary throttling)
Non-retryable errors:
- Authentication errors (401, 403 - won't fix with retry)
- Bad request errors (400 - malformed request)
- Not found errors (404 - resource doesn't exist)
"""
if isinstance(error, requests.exceptions.ConnectionError):
return True # Network issues are usually temporary
if isinstance(error, requests.exceptions.Timeout):
return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None:
status_code = error.response.status_code
return status_code >= 500 or status_code == 429
return False
if isinstance(
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429
return False # All other errors are non-retryable
def _retry_request_sync(self, request_func, *args, **kwargs):
"""
ENHANCEMENT: Synchronous retry logic with intelligent error classification.
Uses exponential backoff with jitter to avoid thundering herd problems.
The wait time increases exponentially but is capped at 30 seconds to
prevent excessive delays. Only retries errors that are likely to succeed
on subsequent attempts.
"""
"""Synchronous retry logic with exponential backoff."""
for attempt in range(self.max_retries):
try:
return request_func(*args, **kwargs)
except Exception as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
except (requests.exceptions.RequestException, Exception) as e:
if attempt == self.max_retries - 1:
raise
# PERFORMANCE OPTIMIZATION: Exponential backoff with cap
# Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
wait_time = (2**attempt) + 0.5
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
)
time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
"""
ENHANCEMENT: Async retry logic with intelligent error classification.
Async version of retry logic that doesn't block the event loop during
wait periods. Uses the same exponential backoff strategy as sync version.
"""
"""Async retry logic with exponential backoff."""
for attempt in range(self.max_retries):
try:
return await request_func(*args, **kwargs)
except Exception as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == self.max_retries - 1:
raise
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
wait_time = (2**attempt) + 0.5
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time) # Non-blocking wait
await asyncio.sleep(wait_time)
def _upload_file(self) -> str:
"""
PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.
Uploads the file to Mistral for OCR processing (sync version).
Uses context manager for file handling to ensure proper resource cleanup.
Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration.
"""
"""Uploads the file to Mistral for OCR processing (sync version)."""
log.info("Uploading file to Mistral API")
url = f"{self.base_url}/files"
url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f:
files = {"file": (self.file_name, f, "application/pdf")}
files = {"file": (file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
# NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint
response = requests.post(
url,
headers=self.headers,
files=files,
data=data,
timeout=self.upload_timeout, # Use specialized upload timeout
stream=False, # Keep as False for this endpoint
timeout=self.timeout,
)
return self._handle_response(response)
@ -276,7 +179,7 @@ class MistralLoader:
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency."""
url = f"{self.base_url}/files"
url = f"{self.BASE_API_URL}/files"
async def upload_request():
# Create multipart writer for streaming upload
@ -306,7 +209,7 @@ class MistralLoader:
url,
data=writer,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
return await self._handle_response_async(response)
@ -322,13 +225,13 @@ class MistralLoader:
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}/url"
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
url, headers=signed_url_headers, params=params, timeout=self.timeout
)
return self._handle_response(response)
@ -347,7 +250,7 @@ class MistralLoader:
self, session: aiohttp.ClientSession, file_id: str
) -> str:
"""Async signed URL retrieval."""
url = f"{self.base_url}/files/{file_id}/url"
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
headers = {**self.headers, "Accept": "application/json"}
@ -358,7 +261,7 @@ class MistralLoader:
url,
headers=headers,
params=params,
timeout=aiohttp.ClientTimeout(total=self.url_timeout),
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
return await self._handle_response_async(response)
@ -374,7 +277,7 @@ class MistralLoader:
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.base_url}/ocr"
url = f"{self.BASE_API_URL}/ocr"
ocr_headers = {
**self.headers,
"Content-Type": "application/json",
@ -391,7 +294,7 @@ class MistralLoader:
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
url, headers=ocr_headers, json=payload, timeout=self.timeout
)
return self._handle_response(response)
@ -408,7 +311,7 @@ class MistralLoader:
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
"""Async OCR processing with timing metrics."""
url = f"{self.base_url}/ocr"
url = f"{self.BASE_API_URL}/ocr"
headers = {
**self.headers,
@ -433,7 +336,7 @@ class MistralLoader:
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
ocr_response = await self._handle_response_async(response)
@ -447,12 +350,10 @@ class MistralLoader:
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}"
url = f"{self.BASE_API_URL}/files/{file_id}"
try:
response = requests.delete(
url, headers=self.headers, timeout=self.cleanup_timeout
)
response = requests.delete(url, headers=self.headers, timeout=30)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
@ -468,10 +369,10 @@ class MistralLoader:
async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}")
async with session.delete(
url=f"{self.base_url}/files/{file_id}",
url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=self.cleanup_timeout
total=30
), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
@ -487,40 +388,29 @@ class MistralLoader:
async def _get_session(self):
"""Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector(
limit=20, # Increased total connection limit for better throughput
limit_per_host=10, # Increased per-host limit for API endpoints
ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes)
limit=10, # Total connection limit
limit_per_host=5, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True,
keepalive_timeout=60, # Increased keepalive for connection reuse
keepalive_timeout=30,
enable_cleanup_closed=True,
force_close=False, # Allow connection reuse
resolver=aiohttp.AsyncResolver(), # Use async DNS resolver
)
timeout = aiohttp.ClientTimeout(
total=self.timeout,
connect=30, # Connection timeout
sock_read=60, # Socket read timeout
)
async with aiohttp.ClientSession(
connector=connector,
timeout=timeout,
timeout=aiohttp.ClientTimeout(total=self.timeout),
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
trust_env=True,
) as session:
yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
"""Process OCR results into Document objects with enhanced metadata."""
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [
Document(
page_content="No text content found",
metadata={"error": "no_pages", "file_name": self.file_name},
page_content="No text content found", metadata={"error": "no_pages"}
)
]
@ -528,44 +418,41 @@ class MistralLoader:
total_pages = len(pages_data)
skipped_pages = 0
# Process pages in a memory-efficient way
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is None or page_index is None:
if page_content is not None and page_index is not None:
# Clean up content efficiently
cleaned_content = (
page_content.strip()
if isinstance(page_content, str)
else str(page_content)
)
if cleaned_content: # Only add non-empty pages
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
},
)
)
else:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
else:
skipped_pages += 1
self._debug_log(
f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
)
continue
# Clean up content efficiently with early exit for empty content
if isinstance(page_content, str):
cleaned_content = page_content.strip()
else:
cleaned_content = str(page_content).strip()
if not cleaned_content:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
continue
# Create document with optimized metadata
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
"content_length": len(cleaned_content),
},
)
)
if skipped_pages > 0:
log.info(
@ -580,11 +467,7 @@ class MistralLoader:
return [
Document(
page_content="No valid text content found in document",
metadata={
"error": "no_valid_pages",
"total_pages": total_pages,
"file_name": self.file_name,
},
metadata={"error": "no_valid_pages", "total_pages": total_pages},
)
]
@ -702,14 +585,12 @@ class MistralLoader:
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]:
"""
Process multiple files concurrently with controlled concurrency.
Process multiple files concurrently for maximum performance.
Args:
loaders: List of MistralLoader instances
max_concurrent: Maximum number of concurrent requests
Returns:
List of document lists, one for each loader
@ -717,20 +598,11 @@ class MistralLoader:
if not loaders:
return []
log.info(
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
log.info(f"Starting concurrent processing of {len(loaders)} files")
start_time = time.time()
# Use semaphore to control concurrency
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
async with semaphore:
return await loader.load_async()
# Process all files with controlled concurrency
tasks = [process_with_semaphore(loader) for loader in loaders]
# Process all files concurrently
tasks = [loader.load_async() for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
@ -752,18 +624,10 @@ class MistralLoader:
else:
processed_results.append(result)
# MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
success_count = sum(
1 for result in results if not isinstance(result, Exception)
)
failure_count = len(results) - success_count
log.info(
f"Batch processing completed in {total_time:.2f}s: "
f"{success_count} files succeeded, {failure_count} files failed, "
f"produced {total_docs} total documents"
f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
)
return processed_results
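The load_multiple_async change above is the main branch's semaphore-based concurrency control. As a standalone illustration of the pattern (not the loader's actual code), the sketch below shows how an asyncio.Semaphore caps the number of in-flight coroutines while asyncio.gather(..., return_exceptions=True) keeps one failure from cancelling the batch; process_one and its sleep are placeholders.

import asyncio
import random


async def process_one(name: str) -> str:
    # Placeholder for a single loader's load_async(): simulate variable work.
    await asyncio.sleep(random.uniform(0.05, 0.2))
    return f"{name}: done"


async def process_all(names: list[str], max_concurrent: int = 5) -> list:
    # The semaphore caps how many process_one() calls are in flight at once.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(name: str):
        async with semaphore:
            return await process_one(name)

    # return_exceptions=True mirrors the loader: one failure does not cancel the batch.
    return await asyncio.gather(*(bounded(n) for n in names), return_exceptions=True)


if __name__ == "__main__":
    print(asyncio.run(process_all([f"file-{i}" for i in range(12)], max_concurrent=3)))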

View file

@ -4,8 +4,10 @@ from typing import Iterator, List, Literal, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class TavilyLoader(BaseLoader):

View file

@ -1,11 +1,12 @@
import logging
from xml.etree.ElementTree import ParseError
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
@ -81,7 +82,6 @@ class YoutubeLoader:
TranscriptsDisabled,
YouTubeTranscriptApi,
)
from youtube_transcript_api.proxies import GenericProxyConfig
except ImportError:
raise ImportError(
'Could not import "youtube_transcript_api" Python package. '
@ -89,16 +89,19 @@ class YoutubeLoader:
)
if self.proxy_url:
youtube_proxies = GenericProxyConfig(
http_url=self.proxy_url, https_url=self.proxy_url
)
youtube_proxies = {
"http": self.proxy_url,
"https": self.proxy_url,
}
# Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else:
youtube_proxies = None
transcript_api = YouTubeTranscriptApi(proxy_config=youtube_proxies)
try:
transcript_list = transcript_api.list(self.video_id)
transcript_list = YouTubeTranscriptApi.list_transcripts(
self.video_id, proxies=youtube_proxies
)
except Exception as e:
log.exception("Loading YouTube transcript failed")
return []
@ -107,37 +110,11 @@ class YoutubeLoader:
for lang in self.language:
try:
transcript = transcript_list.find_transcript([lang])
if transcript.is_generated:
log.debug(f"Found generated transcript for language '{lang}'")
try:
transcript = transcript_list.find_manually_created_transcript(
[lang]
)
log.debug(f"Found manual transcript for language '{lang}'")
except NoTranscriptFound:
log.debug(
f"No manual transcript found for language '{lang}', using generated"
)
pass
log.debug(f"Found transcript for language '{lang}'")
try:
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
except ParseError:
log.debug(f"Empty or invalid transcript for language '{lang}'")
continue
if not transcript_pieces:
log.debug(f"Empty transcript for language '{lang}'")
continue
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript_text = " ".join(
map(
lambda transcript_piece: (
transcript_piece.text.strip(" ")
if hasattr(transcript_piece, "text")
else ""
),
lambda transcript_piece: transcript_piece.text.strip(" "),
transcript_pieces,
)
)
@ -154,11 +131,6 @@ class YoutubeLoader:
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
)
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
async def aload(self) -> Generator[Document, None, None]:
"""Asynchronously load YouTube transcripts into `Document` objects."""
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.load)
raise NoTranscriptFound(
f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
)
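For reference, here are the two proxy call shapes from the YoutubeLoader hunk above, side by side. Both are copied from the diff; which one actually works depends on the installed youtube_transcript_api release, and the proxy URL and video ID below are placeholders.

# Placeholders; neither value is real.
proxy_url = "http://user:pass@proxy.example:8080"
video_id = "dQw4w9WgXcQ"

USE_NEW_API = True

if USE_NEW_API:
    # main: instantiate the API with a GenericProxyConfig (as in the hunk above).
    from youtube_transcript_api import YouTubeTranscriptApi
    from youtube_transcript_api.proxies import GenericProxyConfig

    api = YouTubeTranscriptApi(
        proxy_config=GenericProxyConfig(http_url=proxy_url, https_url=proxy_url)
    )
    transcript_list = api.list(video_id)
else:
    # v0.6.12: class method plus a plain requests-style proxies dict.
    from youtube_transcript_api import YouTubeTranscriptApi

    transcript_list = YouTubeTranscriptApi.list_transcripts(
        video_id, proxies={"http": proxy_url, "https": proxy_url}
    )

print(transcript_list)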

View file

@ -5,10 +5,12 @@ import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT(BaseReranker):
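The reranker modules above and below now inherit from a shared BaseReranker. That module is not part of this diff, so the following is only a hypothetical sketch of what such an interface might look like, inferred from the predict(sentences: List[Tuple[str, str]]) signature used by ExternalReranker below; the DummyReranker scoring rule is purely illustrative.

# Hypothetical sketch; the real open_webui.retrieval.models.base_reranker is not shown here.
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple


class BaseReranker(ABC):
    @abstractmethod
    def predict(
        self, sentences: List[Tuple[str, str]], user=None
    ) -> Optional[List[float]]:
        """Score (query, document) pairs; higher scores mean more relevant."""
        raise NotImplementedError


class DummyReranker(BaseReranker):
    def predict(self, sentences, user=None):
        # Trivial stand-in: score each pair by shared token count.
        return [
            float(len(set(q.lower().split()) & set(d.lower().split())))
            for q, d in sentences
        ]


if __name__ == "__main__":
    print(DummyReranker().predict([("open webui rerank", "rerank documents in open webui")]))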

View file

@ -1,15 +1,13 @@
import logging
import requests
from typing import Optional, List, Tuple
from urllib.parse import quote
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
from open_webui.utils.headers import include_user_info_headers
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalReranker(BaseReranker):
@ -18,16 +16,12 @@ class ExternalReranker(BaseReranker):
api_key: str,
url: str = "http://localhost:8080/v1/rerank",
model: str = "reranker",
timeout: Optional[int] = None,
):
self.api_key = api_key
self.url = url
self.model = model
self.timeout = timeout
def predict(
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
@ -42,19 +36,13 @@ class ExternalReranker(BaseReranker):
log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{self.url}",
headers=headers,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
},
json=payload,
timeout=self.timeout,
)
r.raise_for_status()
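The main branch forwards user information to the external reranker via include_user_info_headers when ENABLE_FORWARD_USER_INFO_HEADERS is set. The helper itself is not shown in this diff, so the sketch below is an assumption about how such a helper could merge headers; the X-OpenWebUI-User-* header names are illustrative, not confirmed by this diff.

# Hypothetical sketch of an include_user_info_headers-style helper; header names are assumed.
from typing import Dict


def include_user_info_headers(headers: Dict[str, str], user) -> Dict[str, str]:
    # Copy so the caller's dict is not mutated, then attach identifying fields.
    merged = dict(headers)
    for header, attr in (
        ("X-OpenWebUI-User-Id", "id"),
        ("X-OpenWebUI-User-Email", "email"),
        ("X-OpenWebUI-User-Name", "name"),
    ):
        value = getattr(user, attr, None)
        if value:
            merged[header] = str(value)
    return merged


if __name__ == "__main__":
    class _User:  # stand-in object with the attributes the helper looks for
        id, email, name = "u-1", "user@example.com", "Alice"

    print(include_user_info_headers({"Authorization": "Bearer ..."}, _User()))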

File diff suppressed because it is too large

View file

@ -11,8 +11,6 @@ from open_webui.retrieval.vector.main import (
SearchResult,
GetResult,
)
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@ -24,8 +22,10 @@ from open_webui.config import (
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient(VectorDBBase):
@ -144,7 +144,7 @@ class ChromaClient(VectorDBBase):
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items]
metadatas = [item["metadata"] for item in items]
for batch in create_batches(
api=self.client,
@ -164,7 +164,7 @@ class ChromaClient(VectorDBBase):
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items]
metadatas = [item["metadata"] for item in items]
collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
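Several vector backends in this diff now run item metadata through process_metadata before writing it. The helper's implementation is not included here, so the following is a hypothetical sketch of the kind of sanitization such a function might perform (keep primitives, JSON-encode everything else); it is not the actual open_webui.retrieval.vector.utils code.

# Hypothetical process_metadata() sketch; the real helper is not shown in this diff.
import json
from typing import Any, Dict


def process_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    processed: Dict[str, Any] = {}
    for key, value in (metadata or {}).items():
        if isinstance(value, (str, int, float, bool)) or value is None:
            processed[key] = value
        else:
            # Lists/dicts/objects become JSON strings so backends with flat
            # metadata schemas (Chroma, OpenSearch, ...) can store them.
            processed[key] = json.dumps(value, default=str)
    return processed


print(process_metadata({"source": "doc.pdf", "pages": [1, 2, 3]}))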

View file

@ -2,8 +2,6 @@ from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
@ -245,7 +243,7 @@ class ElasticsearchClient(VectorDBBase):
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
"metadata": item["metadata"],
},
}
for item in batch
@ -266,7 +264,7 @@ class ElasticsearchClient(VectorDBBase):
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}

View file

@ -1,12 +1,8 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
from pymilvus import connections, Collection
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
@ -22,11 +18,11 @@ from open_webui.config import (
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
MILVUS_DISKANN_MAX_DEGREE,
MILVUS_DISKANN_SEARCH_LIST_SIZE,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient(VectorDBBase):
@ -131,18 +127,12 @@ class MilvusClient(VectorDBBase):
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type == "DISKANN":
index_creation_params = {
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
}
log.info(f"DISKANN params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
@ -195,54 +185,86 @@ class MilvusClient(VectorDBBase):
)
return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: int = -1):
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
max_limit = 16383 # The maximum number of records per request
all_results = []
if limit is None:
# Milvus default limit for query if not specified is 16384, but docs mention iteration.
# Let's set a practical high number if "all" is intended, or handle true pagination.
# For now, if limit is None, we'll fetch in batches up to a very large number.
# This part could be refined based on expected use cases for "get all".
# For this function signature, None implies "as many as possible" up to Milvus limits.
limit = (
16384 * 10
) # A large number to signify fetching many, will be capped by actual data or max_limit per call.
log.info(
f"Limit not specified for query, fetching up to {limit} results in batches."
)
filter_expressions = []
for key, value in filter.items():
if isinstance(value, str):
filter_expressions.append(f'metadata["{key}"] == "{value}"')
else:
filter_expressions.append(f'metadata["{key}"] == {value}')
filter_string = " && ".join(filter_expressions)
collection = Collection(f"{self.collection_prefix}_{collection_name}")
collection.load()
# Initialize offset and remaining to handle pagination
offset = 0
remaining = limit
try:
log.info(
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
)
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
current_fetch = min(
max_limit, remaining if isinstance(remaining, int) else max_limit
)
log.debug(
f"Querying with offset: {offset}, current_fetch: {current_fetch}"
)
iterator = collection.query_iterator(
expr=filter_string,
output_fields=[
"id",
"data",
"metadata",
],
limit=limit if limit > 0 else -1,
)
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
output_fields=[
"id",
"data",
"metadata",
], # Explicitly list needed fields. Vector not usually needed in query.
limit=current_fetch,
offset=offset,
)
all_results = []
while True:
batch = iterator.next()
if not batch:
iterator.close()
if not results:
log.debug("No more results from query.")
break
all_results.extend(batch)
log.debug(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results] if all_results else [[]])
all_results.extend(results)
results_count = len(results)
log.debug(f"Fetched {results_count} results in this batch.")
if isinstance(remaining, int):
remaining -= results_count
offset += results_count
# Break the loop if the results returned are less than the requested fetch count (means end of data)
if results_count < current_fetch:
log.debug(
"Fetched less than requested, assuming end of results for this query."
)
break
log.info(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results])
except Exception as e:
log.exception(
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
@ -257,7 +279,7 @@ class MilvusClient(VectorDBBase):
)
# Using query with a trivial filter to get all items.
# This will use the paginated query logic.
return self.query(collection_name=collection_name, filter={}, limit=-1)
return self.query(collection_name=collection_name, filter={}, limit=None)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
@ -289,7 +311,7 @@ class MilvusClient(VectorDBBase):
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]),
"metadata": item["metadata"],
}
for item in items
],
@ -325,7 +347,7 @@ class MilvusClient(VectorDBBase):
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]),
"metadata": item["metadata"],
}
for item in items
],
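The query() hunk above changes how the Milvus filter expression is built: the main branch serializes every value with json.dumps, while v0.6.12 hand-quotes strings. The standalone sketch below reproduces both styles on an illustrative filter dict to show why the json.dumps form is safer with embedded quotes and mixed types.

import json

metadata_filter = {"file_id": "abc-123", "page": 4, "name": 'He said "hi"'}

# main: json.dumps handles quoting and escaping for every value type.
expr_new = " && ".join(
    f'metadata["{key}"] == {json.dumps(value)}' for key, value in metadata_filter.items()
)

# v0.6.12: quote strings by hand, leave other types bare.
expr_old = " && ".join(
    f'metadata["{key}"] == "{value}"' if isinstance(value, str)
    else f'metadata["{key}"] == {value}'
    for key, value in metadata_filter.items()
)

print(expr_new)
print(expr_old)  # note the unescaped inner quotes in the last term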

View file

@ -1,287 +0,0 @@
import logging
from typing import Optional, Tuple, List, Dict, Any
from open_webui.config import (
MILVUS_URI,
MILVUS_TOKEN,
MILVUS_DB,
MILVUS_COLLECTION_PREFIX,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from pymilvus import (
connections,
utility,
Collection,
CollectionSchema,
FieldSchema,
DataType,
)
log = logging.getLogger(__name__)
RESOURCE_ID_FIELD = "resource_id"
class MilvusClient(VectorDBBase):
def __init__(self):
# Milvus collection names can only contain numbers, letters, and underscores.
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
connections.connect(
alias="default",
uri=MILVUS_URI,
token=MILVUS_TOKEN,
db_name=MILVUS_DB,
)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
self.shared_collections = [
self.MEMORY_COLLECTION,
self.KNOWLEDGE_COLLECTION,
self.FILE_COLLECTION,
self.WEB_SEARCH_COLLECTION,
self.HASH_BASED_COLLECTION,
]
def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
"""
Maps the traditional collection name to multi-tenant collection and resource ID.
WARNING: This mapping relies on current Open WebUI naming conventions for
collection names. If Open WebUI changes how it generates collection names
(e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
formats), this mapping will break and route data to incorrect collections,
potentially causing severe data corruption, data consistency issues, and incorrect
data mapping inside the database.
"""
resource_id = collection_name
if collection_name.startswith("user-memory-"):
return self.MEMORY_COLLECTION, resource_id
elif collection_name.startswith("file-"):
return self.FILE_COLLECTION, resource_id
elif collection_name.startswith("web-search-"):
return self.WEB_SEARCH_COLLECTION, resource_id
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
return self.HASH_BASED_COLLECTION, resource_id
else:
return self.KNOWLEDGE_COLLECTION, resource_id
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
fields = [
FieldSchema(
name="id",
dtype=DataType.VARCHAR,
is_primary=True,
auto_id=False,
max_length=36,
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="metadata", dtype=DataType.JSON),
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
]
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
collection = Collection(mt_collection_name, schema)
index_params = {
"metric_type": MILVUS_METRIC_TYPE,
"index_type": MILVUS_INDEX_TYPE,
"params": {},
}
if MILVUS_INDEX_TYPE == "HNSW":
index_params["params"] = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
}
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
collection.create_index("vector", index_params)
collection.create_index(RESOURCE_ID_FIELD)
log.info(f"Created shared collection: {mt_collection_name}")
return collection
def _ensure_collection(self, mt_collection_name: str, dimension: int):
if not utility.has_collection(mt_collection_name):
self._create_shared_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return False
collection = Collection(mt_collection)
collection.load()
res = collection.query(expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1)
return len(res) > 0
def upsert(self, collection_name: str, items: List[VectorItem]):
if not items:
return
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
dimension = len(items[0]["vector"])
self._ensure_collection(mt_collection, dimension)
collection = Collection(mt_collection)
entities = [
{
"id": item["id"],
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
RESOURCE_ID_FIELD: resource_id,
}
for item in items
]
collection.insert(entities)
def search(
self, collection_name: str, vectors: List[List[float]], limit: int
) -> Optional[SearchResult]:
if not vectors:
return None
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return None
collection = Collection(mt_collection)
collection.load()
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
results = collection.search(
data=vectors,
anns_field="vector",
param=search_params,
limit=limit,
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
output_fields=["id", "text", "metadata"],
)
ids, documents, metadatas, distances = [], [], [], []
for hits in results:
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
for hit in hits:
batch_ids.append(hit.entity.get("id"))
batch_docs.append(hit.entity.get("text"))
batch_metadatas.append(hit.entity.get("metadata"))
batch_dists.append(hit.distance)
ids.append(batch_ids)
documents.append(batch_docs)
metadatas.append(batch_metadatas)
distances.append(batch_dists)
return SearchResult(
ids=ids, documents=documents, metadatas=metadatas, distances=distances
)
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return
collection = Collection(mt_collection)
# Build expression
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if ids:
# Milvus expects a string list for 'in' operator
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
expr.append(f"id in [{id_list_str}]")
if filter:
for key, value in filter.items():
expr.append(f"metadata['{key}'] == '{value}'")
collection.delete(" and ".join(expr))
def reset(self):
for collection_name in self.shared_collections:
if utility.has_collection(collection_name):
utility.drop_collection(collection_name)
def delete_collection(self, collection_name: str):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return
collection = Collection(mt_collection)
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection):
return None
collection = Collection(mt_collection)
collection.load()
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if filter:
for key, value in filter.items():
if isinstance(value, str):
expr.append(f"metadata['{key}'] == '{value}'")
else:
expr.append(f"metadata['{key}'] == {value}")
iterator = collection.query_iterator(
expr=" and ".join(expr),
output_fields=["id", "text", "metadata"],
limit=limit if limit else -1,
)
all_results = []
while True:
batch = iterator.next()
if not batch:
iterator.close()
break
all_results.extend(batch)
ids = [res["id"] for res in all_results]
documents = [res["text"] for res in all_results]
metadatas = [res["metadata"] for res in all_results]
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def get(self, collection_name: str) -> Optional[GetResult]:
return self.query(collection_name, filter={}, limit=None)
def insert(self, collection_name: str, items: List[VectorItem]):
return self.upsert(collection_name, items)
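The multi-tenant client above routes legacy collection names to a handful of shared Milvus collections. The sketch below restates that routing rule as a plain function, with the Milvus connection removed so it can run anywhere; the "open_webui" prefix is an illustrative stand-in for MILVUS_COLLECTION_PREFIX.

# Standalone restatement of _get_collection_and_resource_id's routing rule.
def route(collection_name: str, prefix: str = "open_webui") -> tuple[str, str]:
    if collection_name.startswith("user-memory-"):
        return f"{prefix}_memories", collection_name
    if collection_name.startswith("file-"):
        return f"{prefix}_files", collection_name
    if collection_name.startswith("web-search-"):
        return f"{prefix}_web_search", collection_name
    if len(collection_name) == 63 and all(c in "0123456789abcdef" for c in collection_name):
        return f"{prefix}_hash_based", collection_name
    return f"{prefix}_knowledge", collection_name


for name in ("user-memory-42", "file-abc", "web-search-xyz", "a" * 63, "my-knowledge-base"):
    print(name, "->", route(name))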

View file

@ -2,7 +2,6 @@ from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
@ -158,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"term": {"metadata." + str(field) + ".keyword": value}}
{"match": {"metadata." + str(field): value}}
)
size = limit if limit else 10000
size = limit if limit else 10
try:
result = self.client.search(
@ -201,13 +200,12 @@ class OpenSearchClient(VectorDBBase):
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
"metadata": item["metadata"],
},
}
for item in batch
]
bulk(self.client, actions)
self.client.indices.refresh(self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
@ -223,14 +221,13 @@ class OpenSearchClient(VectorDBBase):
"doc": {
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}
for item in batch
]
bulk(self.client, actions)
self.client.indices.refresh(self._get_index_name(collection_name))
def delete(
self,
@ -254,12 +251,11 @@ class OpenSearchClient(VectorDBBase):
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"term": {"metadata." + str(field) + ".keyword": value}}
{"match": {"metadata." + str(field): value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
self.client.indices.refresh(self._get_index_name(collection_name))
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
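The OpenSearch hunks above swap a match query on metadata fields for a term query on the ".keyword" sub-field. The sketch below shows both filter bodies on an illustrative field: the term form matches the stored value exactly, while the match form analyzes (tokenizes and lowercases) the value first, which can produce looser matches.

# Illustrative field and value; only the query shape matters here.
field, value = "file_id", "ABC-123"

filter_term = {"term": {f"metadata.{field}.keyword": value}}   # exact match (main)
filter_match = {"match": {f"metadata.{field}": value}}         # analyzed match (v0.6.12)

query_body = {
    "query": {
        "bool": {
            "filter": [filter_term],  # swap in filter_match to reproduce the old behavior
        }
    }
}
print(query_body)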

View file

@ -1,941 +0,0 @@
"""
Oracle 23ai Vector Database Client - Fixed Version
# .env
VECTOR_DB = "oracle23ai"
## DBCS or oracle 23ai free
ORACLE_DB_USE_WALLET = false
ORACLE_DB_USER = "DEMOUSER"
ORACLE_DB_PASSWORD = "Welcome123456"
ORACLE_DB_DSN = "localhost:1521/FREEPDB1"
## ADW or ATP
# ORACLE_DB_USE_WALLET = true
# ORACLE_DB_USER = "DEMOUSER"
# ORACLE_DB_PASSWORD = "Welcome123456"
# ORACLE_DB_DSN = "medium"
# ORACLE_DB_DSN = "(description= (retry_count=3)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=xx.oraclecloud.com))(connect_data=(service_name=yy.adb.oraclecloud.com))(security=(ssl_server_dn_match=no)))"
# ORACLE_WALLET_DIR = "/home/opc/adb_wallet"
# ORACLE_WALLET_PASSWORD = "Welcome1"
ORACLE_VECTOR_LENGTH = 768
ORACLE_DB_POOL_MIN = 2
ORACLE_DB_POOL_MAX = 10
ORACLE_DB_POOL_INCREMENT = 1
"""
from typing import Optional, List, Dict, Any, Union
from decimal import Decimal
import logging
import os
import threading
import time
import json
import array
import oracledb
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ORACLE_DB_USE_WALLET,
ORACLE_DB_USER,
ORACLE_DB_PASSWORD,
ORACLE_DB_DSN,
ORACLE_WALLET_DIR,
ORACLE_WALLET_PASSWORD,
ORACLE_VECTOR_LENGTH,
ORACLE_DB_POOL_MIN,
ORACLE_DB_POOL_MAX,
ORACLE_DB_POOL_INCREMENT,
)
log = logging.getLogger(__name__)
class Oracle23aiClient(VectorDBBase):
"""
Oracle Vector Database Client for vector similarity search using Oracle Database 23ai.
This client provides an interface to store, retrieve, and search vector embeddings
in an Oracle database. It uses connection pooling for efficient database access
and supports vector similarity search operations.
Attributes:
pool: Connection pool for Oracle database connections
"""
def __init__(self) -> None:
"""
Initialize the Oracle23aiClient with a connection pool.
Creates a connection pool with configurable min/max connections, initializes
the database schema if needed, and sets up necessary tables and indexes.
Raises:
ValueError: If required configuration parameters are missing
Exception: If database initialization fails
"""
self.pool = None
try:
# Create the appropriate connection pool based on DB type
if ORACLE_DB_USE_WALLET:
self._create_adb_pool()
else: # DBCS
self._create_dbcs_pool()
dsn = ORACLE_DB_DSN
log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]")
with self.get_connection() as connection:
log.info(f"Connection version: {connection.version}")
self._initialize_database(connection)
log.info("Oracle Vector Search initialization complete.")
except Exception as e:
log.exception(f"Error during Oracle Vector Search initialization: {e}")
raise
def _create_adb_pool(self) -> None:
"""
Create connection pool for Oracle Autonomous Database.
Uses wallet-based authentication.
"""
self.pool = oracledb.create_pool(
user=ORACLE_DB_USER,
password=ORACLE_DB_PASSWORD,
dsn=ORACLE_DB_DSN,
min=ORACLE_DB_POOL_MIN,
max=ORACLE_DB_POOL_MAX,
increment=ORACLE_DB_POOL_INCREMENT,
config_dir=ORACLE_WALLET_DIR,
wallet_location=ORACLE_WALLET_DIR,
wallet_password=ORACLE_WALLET_PASSWORD,
)
log.info("Created ADB connection pool with wallet authentication.")
def _create_dbcs_pool(self) -> None:
"""
Create connection pool for Oracle Database Cloud Service.
Uses basic authentication without wallet.
"""
self.pool = oracledb.create_pool(
user=ORACLE_DB_USER,
password=ORACLE_DB_PASSWORD,
dsn=ORACLE_DB_DSN,
min=ORACLE_DB_POOL_MIN,
max=ORACLE_DB_POOL_MAX,
increment=ORACLE_DB_POOL_INCREMENT,
)
log.info("Created DB connection pool with basic authentication.")
def get_connection(self):
"""
Acquire a connection from the connection pool with retry logic.
Returns:
connection: A database connection with output type handler configured
"""
max_retries = 3
for attempt in range(max_retries):
try:
connection = self.pool.acquire()
connection.outputtypehandler = self._output_type_handler
return connection
except oracledb.DatabaseError as e:
(error_obj,) = e.args
log.exception(
f"Connection attempt {attempt + 1} failed: {error_obj.message}"
)
if attempt < max_retries - 1:
wait_time = 2**attempt
log.info(f"Retrying in {wait_time} seconds...")
time.sleep(wait_time)
else:
raise
def start_health_monitor(self, interval_seconds: int = 60):
"""
Start a background thread to periodically check the health of the connection pool.
Args:
interval_seconds (int): Number of seconds between health checks
"""
def _monitor():
while True:
try:
log.info("[HealthCheck] Running periodic DB health check...")
self.ensure_connection()
log.info("[HealthCheck] Connection is healthy.")
except Exception as e:
log.exception(f"[HealthCheck] Connection health check failed: {e}")
time.sleep(interval_seconds)
thread = threading.Thread(target=_monitor, daemon=True)
thread.start()
log.info(f"Started DB health monitor every {interval_seconds} seconds.")
def _reconnect_pool(self):
"""
Attempt to reinitialize the connection pool if it's been closed or broken.
"""
try:
log.info("Attempting to reinitialize the Oracle connection pool...")
# Close existing pool if it exists
if self.pool:
try:
self.pool.close()
except Exception as close_error:
log.warning(f"Error closing existing pool: {close_error}")
# Re-create the appropriate connection pool based on DB type
if ORACLE_DB_USE_WALLET:
self._create_adb_pool()
else: # DBCS
self._create_dbcs_pool()
log.info("Connection pool reinitialized.")
except Exception as e:
log.exception(f"Failed to reinitialize the connection pool: {e}")
raise
def ensure_connection(self):
"""
Ensure the database connection is alive, reconnecting pool if needed.
"""
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM dual")
except Exception as e:
log.exception(
f"Connection check failed: {e}, attempting to reconnect pool..."
)
self._reconnect_pool()
def _output_type_handler(self, cursor, metadata):
"""
Handle Oracle vector type conversion.
Args:
cursor: Oracle database cursor
metadata: Metadata for the column
Returns:
A variable with appropriate conversion for vector types
"""
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
return cursor.var(
metadata.type_code, arraysize=cursor.arraysize, outconverter=list
)
def _initialize_database(self, connection) -> None:
"""
Initialize database schema, tables and indexes.
Creates the document_chunk table and necessary indexes if they don't exist.
Args:
connection: Oracle database connection
Raises:
Exception: If schema initialization fails
"""
with connection.cursor() as cursor:
try:
log.info("Creating Table document_chunk")
cursor.execute(
"""
BEGIN
EXECUTE IMMEDIATE '
CREATE TABLE IF NOT EXISTS document_chunk (
id VARCHAR2(255) PRIMARY KEY,
collection_name VARCHAR2(255) NOT NULL,
text CLOB,
vmetadata JSON,
vector vector(*, float32)
)
';
EXCEPTION
WHEN OTHERS THEN
IF SQLCODE != -955 THEN
RAISE;
END IF;
END;
"""
)
log.info("Creating Index document_chunk_collection_name_idx")
cursor.execute(
"""
BEGIN
EXECUTE IMMEDIATE '
CREATE INDEX IF NOT EXISTS document_chunk_collection_name_idx
ON document_chunk (collection_name)
';
EXCEPTION
WHEN OTHERS THEN
IF SQLCODE != -955 THEN
RAISE;
END IF;
END;
"""
)
log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
cursor.execute(
"""
BEGIN
EXECUTE IMMEDIATE '
CREATE VECTOR INDEX IF NOT EXISTS document_chunk_vector_ivf_idx
ON document_chunk(vector)
ORGANIZATION NEIGHBOR PARTITIONS
DISTANCE COSINE
WITH TARGET ACCURACY 95
PARAMETERS (TYPE IVF, NEIGHBOR PARTITIONS 100)
';
EXCEPTION
WHEN OTHERS THEN
IF SQLCODE != -955 THEN
RAISE;
END IF;
END;
"""
)
connection.commit()
log.info("Database initialization completed successfully.")
except Exception as e:
connection.rollback()
log.exception(f"Error during database initialization: {e}")
raise
def check_vector_length(self) -> None:
"""
Check vector length compatibility (placeholder).
This method would check if the configured vector length matches the database schema.
Currently implemented as a placeholder.
"""
pass
def _vector_to_blob(self, vector: List[float]) -> array.array:
"""
Convert a vector to the packed float array format accepted by python-oracledb for VECTOR columns.
Args:
vector (List[float]): The vector to convert
Returns:
array.array: The vector packed as 32-bit floats
"""
return array.array("f", vector)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
"""
Adjust vector to the expected length if needed.
Args:
vector (List[float]): The vector to adjust
Returns:
List[float]: The adjusted vector
"""
return vector
def _decimal_handler(self, obj):
"""
Handle Decimal objects for JSON serialization.
Args:
obj: Object to serialize
Returns:
float: Converted decimal value
Raises:
TypeError: If object is not JSON serializable
"""
if isinstance(obj, Decimal):
return float(obj)
raise TypeError(f"{obj} is not JSON serializable")
def _metadata_to_json(self, metadata: Dict) -> str:
"""
Convert metadata dictionary to JSON string.
Args:
metadata (Dict): Metadata dictionary
Returns:
str: JSON representation of metadata
"""
return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"
def _json_to_metadata(self, json_str: str) -> Dict:
"""
Convert JSON string to metadata dictionary.
Args:
json_str (str): JSON string
Returns:
Dict: Metadata dictionary
"""
return json.loads(json_str) if json_str else {}
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert vector items into the database.
Args:
collection_name (str): Name of the collection
items (List[VectorItem]): List of vector items to insert
Raises:
Exception: If insertion fails
Example:
>>> client = Oracle23aiClient()
>>> items = [
... {"id": "1", "text": "Sample text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
... {"id": "2", "text": "Another text", "vector": [0.3, 0.4, ...], "metadata": {"source": "doc2"}}
... ]
>>> client.insert("my_collection", items)
"""
log.info(f"Inserting {len(items)} items into collection '{collection_name}'.")
with self.get_connection() as connection:
try:
with connection.cursor() as cursor:
for item in items:
vector_blob = self._vector_to_blob(item["vector"])
metadata_json = self._metadata_to_json(item["metadata"])
cursor.execute(
"""
INSERT INTO document_chunk
(id, collection_name, text, vmetadata, vector)
VALUES (:id, :collection_name, :text, :metadata, :vector)
""",
{
"id": item["id"],
"collection_name": collection_name,
"text": item["text"],
"metadata": metadata_json,
"vector": vector_blob,
},
)
connection.commit()
log.info(
f"Successfully inserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
connection.rollback()
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Update or insert vector items into the database.
If an item with the same ID exists, it will be updated;
otherwise, it will be inserted.
Args:
collection_name (str): Name of the collection
items (List[VectorItem]): List of vector items to upsert
Raises:
Exception: If upsert operation fails
Example:
>>> client = Oracle23aiClient()
>>> items = [
... {"id": "1", "text": "Updated text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
... {"id": "3", "text": "New item", "vector": [0.5, 0.6, ...], "metadata": {"source": "doc3"}}
... ]
>>> client.upsert("my_collection", items)
"""
log.info(f"Upserting {len(items)} items into collection '{collection_name}'.")
with self.get_connection() as connection:
try:
with connection.cursor() as cursor:
for item in items:
vector_blob = self._vector_to_blob(item["vector"])
metadata_json = self._metadata_to_json(item["metadata"])
cursor.execute(
"""
MERGE INTO document_chunk d
USING (SELECT :merge_id as id FROM dual) s
ON (d.id = s.id)
WHEN MATCHED THEN
UPDATE SET
collection_name = :upd_collection_name,
text = :upd_text,
vmetadata = :upd_metadata,
vector = :upd_vector
WHEN NOT MATCHED THEN
INSERT (id, collection_name, text, vmetadata, vector)
VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
""",
{
"merge_id": item["id"],
"upd_collection_name": collection_name,
"upd_text": item["text"],
"upd_metadata": metadata_json,
"upd_vector": vector_blob,
"ins_id": item["id"],
"ins_collection_name": collection_name,
"ins_text": item["text"],
"ins_metadata": metadata_json,
"ins_vector": vector_blob,
},
)
connection.commit()
log.info(
f"Successfully upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
connection.rollback()
log.exception(f"Error during upsert: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""
Search for similar vectors in the database.
Performs vector similarity search using cosine distance.
Args:
collection_name (str): Name of the collection to search
vectors (List[List[Union[float, int]]]): Query vectors to find similar items for
limit (int): Maximum number of results to return per query
Returns:
Optional[SearchResult]: Search results containing ids, distances, documents, and metadata
Example:
>>> client = Oracle23aiClient()
>>> query_vector = [0.1, 0.2, 0.3, ...] # Must match VECTOR_LENGTH
>>> results = client.search("my_collection", [query_vector], limit=5)
>>> if results:
... log.info(f"Found {len(results.ids[0])} matches")
... for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
... log.info(f"Match {i+1}: id={id}, distance={dist}")
"""
log.info(
f"Searching items from collection '{collection_name}' with limit {limit}."
)
try:
if not vectors:
log.warning("No vectors provided for search.")
return None
num_queries = len(vectors)
ids = [[] for _ in range(num_queries)]
distances = [[] for _ in range(num_queries)]
documents = [[] for _ in range(num_queries)]
metadatas = [[] for _ in range(num_queries)]
with self.get_connection() as connection:
with connection.cursor() as cursor:
for qid, vector in enumerate(vectors):
vector_blob = self._vector_to_blob(vector)
cursor.execute(
"""
SELECT dc.id, dc.text,
JSON_SERIALIZE(dc.vmetadata RETURNING VARCHAR2(4096)) as vmetadata,
VECTOR_DISTANCE(dc.vector, :query_vector, COSINE) as distance
FROM document_chunk dc
WHERE dc.collection_name = :collection_name
ORDER BY VECTOR_DISTANCE(dc.vector, :query_vector, COSINE)
FETCH APPROX FIRST :limit ROWS ONLY
""",
{
"query_vector": vector_blob,
"collection_name": collection_name,
"limit": limit,
},
)
results = cursor.fetchall()
for row in results:
ids[qid].append(row[0])
documents[qid].append(
row[1].read()
if isinstance(row[1], oracledb.LOB)
else str(row[1])
)
# 🔧 FIXED: Parse JSON metadata properly
metadata_str = (
row[2].read()
if isinstance(row[2], oracledb.LOB)
else row[2]
)
metadatas[qid].append(self._json_to_metadata(metadata_str))
distances[qid].append(float(row[3]))
log.info(
f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
)
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
log.exception(f"Error during search: {e}")
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""
Query items based on metadata filters.
Retrieves items that match specified metadata criteria.
Args:
collection_name (str): Name of the collection to query
filter (Dict[str, Any]): Metadata filters to apply
limit (Optional[int]): Maximum number of results to return
Returns:
Optional[GetResult]: Query results containing ids, documents, and metadata
Example:
>>> client = Oracle23aiClient()
>>> filter = {"source": "doc1", "category": "finance"}
>>> results = client.query("my_collection", filter, limit=20)
>>> if results:
... print(f"Found {len(results.ids[0])} matching documents")
"""
log.info(f"Querying items from collection '{collection_name}' with filters.")
try:
limit = limit or 100
query = """
SELECT id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
FROM document_chunk
WHERE collection_name = :collection_name
"""
params = {"collection_name": collection_name}
for i, (key, value) in enumerate(filter.items()):
param_name = f"value_{i}"
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
params[param_name] = str(value)
query += " FETCH FIRST :limit ROWS ONLY"
params["limit"] = limit
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(query, params)
results = cursor.fetchall()
if not results:
log.info("No results found for query.")
return None
ids = [[row[0] for row in results]]
documents = [
[
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
for row in results
]
]
# 🔧 FIXED: Parse JSON metadata properly
metadatas = [
[
self._json_to_metadata(
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
)
for row in results
]
]
log.info(f"Query completed. Found {len(results)} results.")
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during query: {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Get all items in a collection.
Retrieves items from a specified collection up to an internal limit.
Args:
collection_name (str): Name of the collection to retrieve
Returns:
Optional[GetResult]: Result containing ids, documents, and metadata
Example:
>>> client = Oracle23aiClient()
>>> results = client.get("my_collection")
>>> if results:
... print(f"Retrieved {len(results.ids[0])} documents from collection")
"""
limit = 1000 # Hardcoded limit for get operation
log.info(
f"Getting items from collection '{collection_name}' with limit {limit}."
)
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT /*+ MONITOR */ id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
FROM document_chunk
WHERE collection_name = :collection_name
FETCH FIRST :limit ROWS ONLY
""",
{"collection_name": collection_name, "limit": limit},
)
results = cursor.fetchall()
if not results:
log.info("No results found.")
return None
ids = [[row[0] for row in results]]
documents = [
[
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
for row in results
]
]
# 🔧 FIXED: Parse JSON metadata properly
metadatas = [
[
self._json_to_metadata(
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
)
for row in results
]
]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during get: {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> None:
"""
Delete items from the database.
Deletes items from a collection based on IDs or metadata filters.
Args:
collection_name (str): Name of the collection to delete from
ids (Optional[List[str]]): Specific item IDs to delete
filter (Optional[Dict[str, Any]]): Metadata filters for deletion
Raises:
Exception: If deletion fails
Example:
>>> client = Oracle23aiClient()
>>> # Delete specific items by ID
>>> client.delete("my_collection", ids=["1", "3", "5"])
>>> # Or delete by metadata filter
>>> client.delete("my_collection", filter={"source": "deprecated_source"})
"""
log.info(f"Deleting items from collection '{collection_name}'.")
try:
query = (
"DELETE FROM document_chunk WHERE collection_name = :collection_name"
)
params = {"collection_name": collection_name}
if ids:
# 🔧 FIXED: Use proper parameterized query to prevent SQL injection
placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
query += f" AND id IN ({placeholders})"
for i, id_val in enumerate(ids):
params[f"id_{i}"] = id_val
if filter:
for i, (key, value) in enumerate(filter.items()):
param_name = f"value_{i}"
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
params[param_name] = str(value)
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(query, params)
deleted = cursor.rowcount
connection.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
log.exception(f"Error during delete: {e}")
raise
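# A minimal sketch of the statement and bind parameters delete() above builds for a
# call like delete("my_collection", ids=["a", "b"]); the collection name and ids are
# illustrative, not taken from the repository.
example_query = (
    "DELETE FROM document_chunk WHERE collection_name = :collection_name"
    " AND id IN (:id_0,:id_1)"
)
example_params = {"collection_name": "my_collection", "id_0": "a", "id_1": "b"}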
def reset(self) -> None:
"""
Reset the database by deleting all items.
Deletes all items from the document_chunk table.
Raises:
Exception: If reset fails
Example:
>>> client = Oracle23aiClient()
>>> client.reset() # Warning: Removes all data!
"""
log.info("Resetting database - deleting all items.")
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("DELETE FROM document_chunk")
deleted = cursor.rowcount
connection.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
log.exception(f"Error during reset: {e}")
raise
def close(self) -> None:
"""
Close the database connection pool.
Properly closes the connection pool and releases all resources.
Example:
>>> client = Oracle23aiClient()
>>> # After finishing all operations
>>> client.close()
"""
try:
if hasattr(self, "pool") and self.pool:
self.pool.close()
log.info("Oracle Vector Search connection pool closed.")
except Exception as e:
log.exception(f"Error closing connection pool: {e}")
def has_collection(self, collection_name: str) -> bool:
"""
Check if a collection exists.
Args:
collection_name (str): Name of the collection to check
Returns:
bool: True if the collection exists, False otherwise
Example:
>>> client = Oracle23aiClient()
>>> if client.has_collection("my_collection"):
... print("Collection exists!")
... else:
... print("Collection does not exist.")
"""
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT COUNT(*)
FROM document_chunk
WHERE collection_name = :collection_name
FETCH FIRST 1 ROWS ONLY
""",
{"collection_name": collection_name},
)
count = cursor.fetchone()[0]
return count > 0
except Exception as e:
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
"""
Delete an entire collection.
Removes all items belonging to the specified collection.
Args:
collection_name (str): Name of the collection to delete
Example:
>>> client = Oracle23aiClient()
>>> client.delete_collection("obsolete_collection")
"""
log.info(f"Deleting collection '{collection_name}'.")
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(
"""
DELETE FROM document_chunk
WHERE collection_name = :collection_name
""",
{"collection_name": collection_name},
)
deleted = cursor.rowcount
connection.commit()
log.info(
f"Collection '{collection_name}' deleted. Removed {deleted} items."
)
except Exception as e:
log.exception(f"Error deleting collection '{collection_name}': {e}")
raise
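# A hedged end-to-end usage sketch assembled from the method docstrings above; the
# collection name is illustrative and the no-argument constructor mirrors the
# docstring examples rather than any configuration shown here.
client = Oracle23aiClient()
try:
    if client.has_collection("my_collection"):
        results = client.get("my_collection")
        if results:
            print(f"{len(results.ids[0])} chunks stored")
        client.delete_collection("my_collection")
finally:
    client.close()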

View file

@ -1,16 +1,12 @@
from typing import Optional, List, Dict, Any, Tuple
from typing import Optional, List, Dict, Any
import logging
import json
from sqlalchemy import (
func,
literal,
cast,
column,
create_engine,
Column,
Integer,
MetaData,
LargeBinary,
select,
text,
Text,
@ -18,71 +14,39 @@ from sqlalchemy import (
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector, HALFVEC
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PGVECTOR_DB_URL,
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_CREATE_EXTENSION,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
PGVECTOR_POOL_SIZE,
PGVECTOR_POOL_MAX_OVERFLOW,
PGVECTOR_POOL_TIMEOUT,
PGVECTOR_POOL_RECYCLE,
PGVECTOR_INDEX_METHOD,
PGVECTOR_HNSW_M,
PGVECTOR_HNSW_EF_CONSTRUCTION,
PGVECTOR_IVFFLAT_LISTS,
PGVECTOR_USE_HALFVEC,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
Base = declarative_base()
log = logging.getLogger(__name__)
def pgcrypto_encrypt(val, key):
return func.pgp_sym_encrypt(val, literal(key))
def pgcrypto_decrypt(col, key, outtype="text"):
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
id = Column(Text, primary_key=True)
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False)
if PGVECTOR_PGCRYPTO:
text = Column(LargeBinary, nullable=True)
vmetadata = Column(LargeBinary, nullable=True)
else:
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
@ -94,24 +58,9 @@ class PgvectorClient(VectorDBBase):
self.session = Session
else:
if isinstance(PGVECTOR_POOL_SIZE, int):
if PGVECTOR_POOL_SIZE > 0:
engine = create_engine(
PGVECTOR_DB_URL,
pool_size=PGVECTOR_POOL_SIZE,
max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
pool_timeout=PGVECTOR_POOL_TIMEOUT,
pool_recycle=PGVECTOR_POOL_RECYCLE,
pool_pre_ping=True,
poolclass=QueuePool,
)
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
@ -119,41 +68,7 @@ class PgvectorClient(VectorDBBase):
try:
# Ensure the pgvector extension is available
# Use a conditional check to avoid permission issues on Azure PostgreSQL
if PGVECTOR_CREATE_EXTENSION:
self.session.execute(
text(
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
CREATE EXTENSION IF NOT EXISTS vector;
END IF;
END $$;
"""
)
)
if PGVECTOR_PGCRYPTO:
# Ensure the pgcrypto extension is available for encryption
# Use a conditional check to avoid permission issues on Azure PostgreSQL
self.session.execute(
text(
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
CREATE EXTENSION IF NOT EXISTS pgcrypto;
END IF;
END $$;
"""
)
)
if not PGVECTOR_PGCRYPTO_KEY:
raise ValueError(
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
)
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
# Check vector length consistency
self.check_vector_length()
@ -164,9 +79,13 @@ class PgvectorClient(VectorDBBase):
connection = self.session.connection()
Base.metadata.create_all(bind=connection)
index_method, index_options = self._vector_index_configuration()
self._ensure_vector_index(index_method, index_options)
# Create an index on the vector column if it doesn't exist
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
)
)
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
@ -180,78 +99,6 @@ class PgvectorClient(VectorDBBase):
log.exception(f"Error during initialization: {e}")
raise
@staticmethod
def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
if not index_def:
return None
try:
after_using = index_def.lower().split("using ", 1)[1]
return after_using.split()[0]
except (IndexError, AttributeError):
return None
def _vector_index_configuration(self) -> Tuple[str, str]:
if PGVECTOR_INDEX_METHOD:
index_method = PGVECTOR_INDEX_METHOD
log.info(
"Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
index_method,
)
elif USE_HALFVEC:
index_method = "hnsw"
log.info(
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
VECTOR_LENGTH,
)
else:
index_method = "ivfflat"
if index_method == "hnsw":
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
else:
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
return index_method, index_options
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
index_name = "idx_document_chunk_vector"
existing_index_def = self.session.execute(
text(
"""
SELECT indexdef
FROM pg_indexes
WHERE schemaname = current_schema()
AND tablename = 'document_chunk'
AND indexname = :index_name
"""
),
{"index_name": index_name},
).scalar()
existing_method = self._extract_index_method(existing_index_def)
if existing_method and existing_method != index_method:
raise RuntimeError(
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
"and recreate it with the new method before restarting Open WebUI."
)
if not existing_index_def:
index_sql = (
f"CREATE INDEX IF NOT EXISTS {index_name} "
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
)
if index_options:
index_sql = f"{index_sql} {index_options}"
self.session.execute(text(index_sql))
log.info(
"Ensured vector index '%s' using %s%s.",
index_name,
index_method,
f" {index_options}" if index_options else "",
)
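# A sketch of the DDL that _vector_index_configuration() and _ensure_vector_index()
# above generate; the settings are the ones imported from open_webui.config, and the
# opclass switches between halfvec_cosine_ops and vector_cosine_ops depending on
# PGVECTOR_USE_HALFVEC.
hnsw_ddl = (
    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
    "ON document_chunk USING hnsw (vector halfvec_cosine_ops) "
    f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
)
ivfflat_ddl = (
    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
    "ON document_chunk USING ivfflat (vector vector_cosine_ops) "
    f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
)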
def check_vector_length(self) -> None:
"""
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
@ -271,19 +118,16 @@ class PgvectorClient(VectorDBBase):
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
vector_type = vector_column.type
expected_type = HALFVEC if USE_HALFVEC else Vector
if not isinstance(vector_type, expected_type):
if isinstance(vector_type, Vector):
db_vector_length = vector_type.dim
if db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
)
else:
raise Exception(
"The 'vector' column type does not match the expected type "
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
)
db_vector_length = getattr(vector_type, "dim", None)
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
"The 'vector' column exists but is not of type 'Vector'."
)
else:
raise Exception(
@ -303,54 +147,22 @@ class PgvectorClient(VectorDBBase):
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
# Use raw SQL for BYTEA/pgcrypto
# Ensure metadata is converted to its JSON text representation
json_metadata = json.dumps(item["metadata"])
self.session.execute(
text(
"""
INSERT INTO document_chunk
(id, vector, collection_name, text, vmetadata)
VALUES (
:id, :vector, :collection_name,
pgp_sym_encrypt(:text, :key),
pgp_sym_encrypt(:metadata_text, :key)
)
ON CONFLICT (id) DO NOTHING
"""
),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
else:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
@ -358,66 +170,33 @@ class PgvectorClient(VectorDBBase):
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
json_metadata = json.dumps(item["metadata"])
self.session.execute(
text(
"""
INSERT INTO document_chunk
(id, vector, collection_name, text, vmetadata)
VALUES (
:id, :vector, :collection_name,
pgp_sym_encrypt(:text, :key),
pgp_sym_encrypt(:metadata_text, :key)
)
ON CONFLICT (id) DO UPDATE SET
vector = EXCLUDED.vector,
collection_name = EXCLUDED.collection_name,
text = EXCLUDED.text,
vmetadata = EXCLUDED.vmetadata
"""
),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
else:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = process_metadata(item["metadata"])
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during upsert: {e}")
@ -438,11 +217,11 @@ class PgvectorClient(VectorDBBase):
num_queries = len(vectors)
def vector_expr(vector):
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
return cast(array(vector), Vector(VECTOR_LENGTH))
# Create the values for query vectors
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
@ -451,32 +230,16 @@ class PgvectorClient(VectorDBBase):
.alias("query_vectors")
)
result_fields = [
DocumentChunk.id,
]
if PGVECTOR_PGCRYPTO:
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text")
)
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata")
)
else:
result_fields.append(DocumentChunk.text)
result_fields.append(DocumentChunk.vmetadata)
result_fields.append(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
)
)
# Build the lateral subquery for each query vector
subq = (
select(*result_fields)
select(
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
).label("distance"),
)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
@ -525,12 +288,10 @@ class PgvectorClient(VectorDBBase):
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)
self.session.rollback() # read-only transaction
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during search: {e}")
return None
@ -538,43 +299,17 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
try:
if PGVECTOR_PGCRYPTO:
# Build where clause for vmetadata filter
where_clauses = [DocumentChunk.collection_name == collection_name]
for key, value in filter.items():
# decrypt then check key: JSON filter after decryption
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
== str(value)
)
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
).where(*where_clauses)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
if limit is not None:
query = query.limit(limit)
if limit is not None:
query = query.limit(limit)
results = query.all()
results = query.all()
if not results:
return None
@ -583,14 +318,12 @@ class PgvectorClient(VectorDBBase):
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
self.session.rollback() # read-only transaction
return GetResult(
ids=ids,
documents=documents,
metadatas=metadatas,
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during query: {e}")
return None
@ -598,43 +331,23 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
try:
if PGVECTOR_PGCRYPTO:
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
).where(DocumentChunk.collection_name == collection_name)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
ids = [[row.id for row in results]]
documents = [[row.text for row in results]]
metadatas = [[row.vmetadata for row in results]]
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
results = query.all()
results = query.all()
if not results:
return None
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
self.session.rollback() # read-only transaction
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Error during get: {e}")
return None
@ -645,33 +358,17 @@ class PgvectorClient(VectorDBBase):
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
if PGVECTOR_PGCRYPTO:
wheres = [DocumentChunk.collection_name == collection_name]
if ids:
wheres.append(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
wheres.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
== str(value)
)
stmt = DocumentChunk.__table__.delete().where(*wheres)
result = self.session.execute(stmt)
deleted = result.rowcount
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
@ -702,10 +399,8 @@ class PgvectorClient(VectorDBBase):
.first()
is not None
)
self.session.rollback() # read-only transaction
return exists
except Exception as e:
self.session.rollback()
log.exception(f"Error checking collection existence: {e}")
return False

View file

@ -3,19 +3,10 @@ import logging
import time # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec
# Add gRPC support for better performance (Pinecone best practice)
try:
from pinecone.grpc import PineconeGRPC
GRPC_AVAILABLE = True
except ImportError:
GRPC_AVAILABLE = False
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
import random # for jitter in retry backoff
from open_webui.retrieval.vector.main import (
VectorDBBase,
@ -31,13 +22,13 @@ from open_webui.config import (
PINECONE_METRIC,
PINECONE_CLOUD,
)
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient(VectorDBBase):
@ -56,24 +47,7 @@ class PineconeClient(VectorDBBase):
self.cloud = PINECONE_CLOUD
# Initialize Pinecone client for improved performance
if GRPC_AVAILABLE:
# Use gRPC client for better performance (Pinecone recommendation)
self.client = PineconeGRPC(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = True
log.info("Using Pinecone gRPC client for optimal performance")
else:
# Fallback to HTTP client with enhanced connection pooling
self.client = Pinecone(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = False
log.info("Using Pinecone HTTP client (gRPC not available)")
self.client = Pinecone(api_key=self.api_key)
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@ -117,53 +91,12 @@ class PineconeClient(VectorDBBase):
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
self.index = self.client.Index(
self.index_name,
pool_threads=20, # Enhanced connection pool for index operations
)
self.index = self.client.Index(self.index_name)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _retry_pinecone_operation(self, operation_func, max_retries=3):
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
for attempt in range(max_retries):
try:
return operation_func()
except Exception as e:
error_str = str(e).lower()
# Check if it's a retryable error (rate limits, network issues, timeouts)
is_retryable = any(
keyword in error_str
for keyword in [
"rate limit",
"quota",
"timeout",
"network",
"connection",
"unavailable",
"internal error",
"429",
"500",
"502",
"503",
"504",
]
)
if not is_retryable or attempt == max_retries - 1:
# Don't retry for non-retryable errors or on final attempt
raise
# Exponential backoff with jitter
delay = (2**attempt) + random.uniform(0, 1)
log.warning(
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
f"retrying in {delay:.2f}s: {e}"
)
time.sleep(delay)
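# A hypothetical illustration of wrapping an index call with the retry helper above;
# `batch` is an assumed list of point dicts and is not defined in this diff.
#     self._retry_pinecone_operation(
#         functools.partial(self.index.upsert, vectors=batch)
#     )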
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
@ -183,7 +116,7 @@ class PineconeClient(VectorDBBase):
point = {
"id": item["id"],
"values": item["vector"],
"metadata": process_metadata(metadata),
"metadata": metadata,
}
points.append(point)
return points
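# A hypothetical example of one point dict in the list returned above; the id,
# vector values and metadata are invented.
example_point = {
    "id": "chunk-1",
    "values": [0.12, -0.03, 0.88],
    "metadata": {"text": "chunk body", "source": "doc.pdf"},
}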
@ -290,8 +223,7 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully inserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@ -322,8 +254,7 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully upserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@ -354,8 +285,7 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async insert batch: {result}")
raise result
log.info(
f"Successfully async inserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@ -386,8 +316,7 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async upsert batch: {result}")
raise result
log.info(
f"Successfully async upserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
def search(
@ -528,12 +457,10 @@ class PineconeClient(VectorDBBase):
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
)
elif filter:

View file

@ -18,26 +18,23 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI:
self.client = None
@ -55,14 +52,9 @@ class QdrantClient(VectorDBBase):
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
)
else:
self.client = Qclient(
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY,
timeout=QDRANT_TIMEOUT,
)
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
def _result_to_get_result(self, points) -> GetResult:
ids = []
@ -92,30 +84,8 @@ class QdrantClient(VectorDBBase):
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
hnsw_config=models.HnswConfigDiff(
m=self.QDRANT_HNSW_M,
),
)
# Create payload indexes for efficient filtering
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.hash",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.file_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
@ -181,23 +151,23 @@ class QdrantClient(VectorDBBase):
)
)
points = self.client.scroll(
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
scroll_filter=models.Filter(should=field_conditions),
query_filter=models.Filter(should=field_conditions),
limit=limit,
)
return self._result_to_get_result(points[0])
return self._result_to_get_result(points.points)
except Exception as e:
log.exception(f"Error querying a collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
points = self.client.scroll(
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points[0])
return self._result_to_get_result(points.points)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.

View file

@ -1,5 +1,5 @@
import logging
from typing import Optional, Tuple, List, Dict, Any
from typing import Optional, Tuple
from urllib.parse import urlparse
import grpc
@ -9,10 +9,8 @@ from open_webui.config import (
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
@ -25,61 +23,39 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
TENANT_ID_FIELD = "tenant_id"
DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__)
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
return models.FieldCondition(
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
return models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI:
raise ValueError(
"QDRANT_URI is not set. Please configure it in the environment variables."
)
self.client = None
return
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
self.client = (
Qclient(
if self.PREFER_GRPC:
self.client = Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
)
if self.PREFER_GRPC
else Qclient(
url=self.QDRANT_URI,
api_key=self.QDRANT_API_KEY,
timeout=self.QDRANT_TIMEOUT,
)
)
else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
@ -89,13 +65,23 @@ class QdrantClient(VectorDBBase):
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult:
ids, documents, metadatas = [], [], []
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
"""
@ -103,13 +89,6 @@ class QdrantClient(VectorDBBase):
Returns:
tuple: (collection_name, tenant_id)
WARNING: This mapping relies on current Open WebUI naming conventions for
collection names. If Open WebUI changes how it generates collection names
(e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
formats), this mapping will break and route data to incorrect collections.
Breaking this mapping could cause severe data corruption, data consistency issues,
and incorrect data mapping inside the database.
"""
# Check for user memory collections
tenant_id = collection_name
@ -134,53 +113,143 @@ class QdrantClient(VectorDBBase):
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
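# A hypothetical illustration of the mapping described in the warning above, assuming
# the default "open-webui" prefix and that "user-memory-" names route to the shared
# memories collection; the id is invented and the exact routing lives in the branches above.
#     _get_collection_and_tenant_id("user-memory-1234")
#     # -> ("open-webui_memories", "user-memory-1234")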
def _create_multi_tenant_collection(
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
def _extract_error_message(self, exception):
"""
Extract error message from either HTTP or gRPC exceptions
Returns:
tuple: (status_code, error_message)
"""
# Check if it's an HTTP exception
if isinstance(exception, UnexpectedResponse):
try:
error_data = exception.structured()
error_msg = error_data.get("status", {}).get("error", "")
return exception.status_code, error_msg
except Exception as inner_e:
log.error(f"Failed to parse HTTP error: {inner_e}")
return exception.status_code, str(exception)
# Check if it's a gRPC exception
elif isinstance(exception, grpc.RpcError):
# Extract status code from gRPC error
status_code = None
if hasattr(exception, "code") and callable(exception.code):
status_code = exception.code().value[0]
# Extract error message
error_msg = str(exception)
if "details =" in error_msg:
# Parse the details line which contains the actual error message
try:
details_line = [
line.strip()
for line in error_msg.split("\n")
if "details =" in line
][0]
error_msg = details_line.split("details =")[1].strip(' "')
except (IndexError, AttributeError):
# Fall back to full message if parsing fails
pass
return status_code, error_msg
# For any other type of exception
return None, str(exception)
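# A hypothetical sketch of what _extract_error_message() above yields for a gRPC
# NOT_FOUND error whose string form contains a line like
#     details = "Collection `foo` doesn't exist"
# It would return roughly (5, "Collection `foo` doesn't exist"), since
# grpc.StatusCode.NOT_FOUND.value[0] == 5; the message text here is invented.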
def _is_collection_not_found_error(self, exception):
"""
Check if the exception is due to collection not found, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# HTTP error (404)
if (
status_code == 404
and "Collection" in error_msg
and "doesn't exist" in error_msg
):
return True
# gRPC error (NOT_FOUND status)
if (
isinstance(exception, grpc.RpcError)
and exception.code() == grpc.StatusCode.NOT_FOUND
):
return True
return False
def _is_dimension_mismatch_error(self, exception):
"""
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# Common patterns in both HTTP and gRPC
return (
"Vector dimension error" in error_msg
or "dimensions mismatch" in error_msg
or "invalid vector size" in error_msg
)
def _create_multi_tenant_collection_if_not_exists(
self, mt_collection_name: str, dimension: int = 384
):
"""
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
Creates a collection with multi-tenancy configuration if it doesn't exist.
Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
"""
self.client.create_collection(
collection_name=mt_collection_name,
vectors_config=models.VectorParams(
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
# Disable global index building due to multitenancy
# For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
hnsw_config=models.HnswConfigDiff(
payload_m=self.QDRANT_HNSW_M,
m=0,
),
)
log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
)
self.client.create_payload_index(
collection_name=mt_collection_name,
field_name=TENANT_ID_FIELD,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
on_disk=self.QDRANT_ON_DISK,
),
)
for field in ("metadata.hash", "metadata.file_id"):
self.client.create_payload_index(
try:
# Try to create the collection directly - will fail if it already exists
self.client.create_collection(
collection_name=mt_collection_name,
field_name=field,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
vectors_config=models.VectorParams(
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
hnsw_config=models.HnswConfigDiff(
payload_m=16, # Enable per-tenant indexing
m=0,
on_disk=self.QDRANT_ON_DISK,
),
)
def _create_points(
self, items: List[VectorItem], tenant_id: str
) -> List[PointStruct]:
# Create tenant ID payload index
self.client.create_payload_index(
collection_name=mt_collection_name,
field_name="tenant_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
on_disk=self.QDRANT_ON_DISK,
),
wait=True,
)
log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
)
except (UnexpectedResponse, grpc.RpcError) as e:
# Check for the specific error indicating collection already exists
status_code, error_msg = self._extract_error_message(e)
# HTTP status code 409 or gRPC ALREADY_EXISTS
if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
isinstance(e, grpc.RpcError)
and e.code() == grpc.StatusCode.ALREADY_EXISTS
):
if "already exists" in error_msg:
log.debug(f"Collection {mt_collection_name} already exists")
return
# If it's not an already exists error, re-raise
raise e
except Exception as e:
raise e
def _create_points(self, items: list[VectorItem], tenant_id: str):
"""
Create point structs from vector items with tenant ID.
"""
@ -191,42 +260,56 @@ class QdrantClient(VectorDBBase):
payload={
"text": item["text"],
"metadata": item["metadata"],
TENANT_ID_FIELD: tenant_id,
"tenant_id": tenant_id,
},
)
for item in items
]
def _ensure_collection(
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
"""
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
"""
if not self.client.collection_exists(collection_name=mt_collection_name):
self._create_multi_tenant_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool:
"""
Check if a logical collection exists by checking for any points with the tenant ID.
"""
if not self.client:
return False
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
return False
tenant_filter = _tenant_filter(tenant_id)
count_result = self.client.count(
collection_name=mt_collection,
count_filter=models.Filter(must=[tenant_filter]),
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
return count_result.count > 0
try:
# Try directly querying - most of the time collection should exist
response = self.client.query_points(
collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]),
limit=1,
)
# Collection exists with this tenant ID if there are points
return len(response.points) > 0
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist")
return False
else:
# For other API errors, log and return False
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
return False
except Exception as e:
# For any other errors, log and return False
log.debug(f"Error checking collection {mt_collection}: {e}")
return False
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
"""
Delete vectors by ID or filter from a collection with tenant isolation.
@ -234,76 +317,189 @@ class QdrantClient(VectorDBBase):
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
return None
must_conditions = [_tenant_filter(tenant_id)]
should_conditions = []
if ids:
should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
elif filter:
must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
return self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions)
),
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
must_conditions = [tenant_filter]
should_conditions = []
if ids:
for id_value in ids:
should_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
)
elif filter:
for key, value in filter.items():
must_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
)
try:
# Try to delete directly - most of the time collection should exist
update_result = self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions)
),
)
return update_result
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, nothing to delete"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def search(
self, collection_name: str, vectors: List[List[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
"""
Search for the nearest neighbor items based on the vectors with tenant isolation.
"""
if not self.client or not vectors:
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
# Get the vector dimension from the query vector
dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
try:
# Try the search operation directly - most of the time collection should exist
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Ensure vector dimensions match the collection
collection_dim = self.client.get_collection(
mt_collection
).config.params.vectors.size
if collection_dim != dimension:
if collection_dim < dimension:
vectors = [vector[:collection_dim] for vector in vectors]
else:
vectors = [
vector + [0] * (collection_dim - dimension)
for vector in vectors
]
# Search with tenant filter
prefetch_query = models.Prefetch(
filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
query_response = self.client.query_points(
collection_name=mt_collection,
query=vectors[0],
prefetch=prefetch_query,
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
# qdrant distance is [-1, 1], normalize to [0, 1]
distances=[
[(point.score + 1.0) / 2.0 for point in query_response.points]
],
)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, search returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during search: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error searching collection '{collection_name}': {e}")
return None
tenant_filter = _tenant_filter(tenant_id)
query_response = self.client.query_points(
collection_name=mt_collection,
query=vectors[0],
limit=limit,
query_filter=models.Filter(must=[tenant_filter]),
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
)
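# A small numeric sketch of the dimension-adjustment branch shown above; the
# collection dimension (4) and the query vectors are illustrative values.
collection_dim = 4
queries = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.1, 0.2, 0.3]]
adjusted = [
    q[:collection_dim]
    if len(q) >= collection_dim
    else q + [0] * (collection_dim - len(q))
    for q in queries
]
# adjusted == [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0]]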
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
):
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
"""
Query points with filters and tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
return None
# Set default limit if not provided
if limit is None:
limit = NO_LIMIT
tenant_filter = _tenant_filter(tenant_id)
field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
points = self.client.scroll(
collection_name=mt_collection,
scroll_filter=combined_filter,
limit=limit,
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
return self._result_to_get_result(points[0])
# Create metadata filters
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
# Combine tenant filter with metadata filters
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
try:
# Try the query directly - most of the time collection should exist
points = self.client.query_points(
collection_name=mt_collection,
query_filter=combined_filter,
limit=limit,
)
return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, query returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during query: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and re-raise
log.exception(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""
@ -311,36 +507,169 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
tenant_filter = _tenant_filter(tenant_id)
points = self.client.scroll(
collection_name=mt_collection,
scroll_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
return self._result_to_get_result(points[0])
def upsert(self, collection_name: str, items: List[VectorItem]):
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
try:
# Try to get points directly - most of the time collection should exist
points = self.client.query_points(
collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during get: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error getting collection '{collection_name}': {e}")
return None
def _handle_operation_with_error_retry(
self, operation_name, mt_collection, points, dimension
):
"""
Private helper to handle common error cases for insert and upsert operations.
Args:
operation_name: 'insert' or 'upsert'
mt_collection: The multi-tenant collection name
points: The vector points to insert/upsert
dimension: The dimension of the vectors
Returns:
The operation result (for upsert) or None (for insert)
"""
try:
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
except (UnexpectedResponse, grpc.RpcError) as e:
# Handle collection not found
if self._is_collection_not_found_error(e):
log.info(
f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
)
# Create collection with correct dimensions from our vectors
self._create_multi_tenant_collection_if_not_exists(
mt_collection_name=mt_collection, dimension=dimension
)
# Try operation again - no need for dimension adjustment since we just created with correct dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
# Handle dimension mismatch
elif self._is_dimension_mismatch_error(e):
# For dimension errors, the collection must exist, so get its configuration
mt_collection_info = self.client.get_collection(mt_collection)
existing_size = mt_collection_info.config.params.vectors.size
log.info(
f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
)
if existing_size < dimension:
# Truncate vectors to fit
log.info(
f"Truncating vectors from {dimension} to {existing_size} dimensions"
)
points = [
PointStruct(
id=point.id,
vector=point.vector[:existing_size],
payload=point.payload,
)
for point in points
]
elif existing_size > dimension:
# Pad vectors with zeros
log.info(
f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
)
points = [
PointStruct(
id=point.id,
vector=point.vector
+ [0] * (existing_size - len(point.vector)),
payload=point.payload,
)
for point in points
]
# Try operation again with adjusted dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
else:
# Not a known error we can handle, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unhandled Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def insert(self, collection_name: str, items: list[VectorItem]):
"""
Insert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"insert", mt_collection, points, dimension
)
def upsert(self, collection_name: str, items: list[VectorItem]):
"""
Upsert items with tenant ID.
"""
if not self.client or not items:
return None
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
dimension = len(items[0]["vector"])
self._ensure_collection(mt_collection, dimension)
points = self._create_points(items, tenant_id)
self.client.upload_points(mt_collection, points)
return None
def insert(self, collection_name: str, items: List[VectorItem]):
"""
Insert items with tenant ID.
"""
return self.upsert(collection_name, items)
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"upsert", mt_collection, points, dimension
)
def reset(self):
"""
@ -348,9 +677,11 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
for collection in self.client.get_collections().collections:
if collection.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection.name)
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
def delete_collection(self, collection_name: str):
"""
@ -358,13 +689,24 @@ class QdrantClient(VectorDBBase):
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
return None
self.client.delete(
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
field_conditions = [tenant_filter]
update_result = self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=[_tenant_filter(tenant_id)])
filter=models.Filter(must=field_conditions)
),
)
if self.client.get_collection(mt_collection).points_count == 0:
self.client.delete_collection(mt_collection)
return update_result

View file

@ -1,774 +0,0 @@
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
GetResult,
SearchResult,
)
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
from typing import List, Optional, Dict, Any, Union
import logging
import boto3
log = logging.getLogger(__name__)
class S3VectorClient(VectorDBBase):
"""
AWS S3 Vector integration for Open WebUI Knowledge.
"""
def __init__(self):
self.bucket_name = S3_VECTOR_BUCKET_NAME
self.region = S3_VECTOR_REGION
# Simple validation - log warnings instead of raising exceptions
if not self.bucket_name:
log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
if not self.region:
log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
if self.bucket_name and self.region:
try:
self.client = boto3.client("s3vectors", region_name=self.region)
log.info(
f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
)
except Exception as e:
log.error(f"Failed to initialize S3Vector client: {e}")
self.client = None
else:
self.client = None
def _create_index(
self,
index_name: str,
dimension: int,
data_type: str = "float32",
distance_metric: str = "cosine",
) -> None:
"""
Create a new index in the S3 vector bucket for the given collection if it does not exist.
"""
if self.has_collection(index_name):
log.debug(f"Index '{index_name}' already exists, skipping creation")
return
try:
self.client.create_index(
vectorBucketName=self.bucket_name,
indexName=index_name,
dataType=data_type,
dimension=dimension,
distanceMetric=distance_metric,
)
log.info(
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
)
except Exception as e:
log.error(f"Error creating S3 index '{index_name}': {e}")
raise
def _filter_metadata(
self, metadata: Dict[str, Any], item_id: str
) -> Dict[str, Any]:
"""
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
"""
if not isinstance(metadata, dict) or len(metadata) <= 10:
return metadata
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
important_keys = [
"text", # The actual document content
"file_id", # File ID
"source", # Document source file
"title", # Document title
"page", # Page number
"total_pages", # Total pages in document
"embedding_config", # Embedding configuration
"created_by", # User who created it
"name", # Document name
"hash", # Content hash
]
filtered_metadata = {}
# First, add important keys if they exist
for key in important_keys:
if key in metadata:
filtered_metadata[key] = metadata[key]
if len(filtered_metadata) >= 10:
break
# If we still have room, add other keys
if len(filtered_metadata) < 10:
for key, value in metadata.items():
if key not in filtered_metadata:
filtered_metadata[key] = value
if len(filtered_metadata) >= 10:
break
log.warning(
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
)
return filtered_metadata
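# A standalone sketch of the 10-key cap applied above (priority keys first,
# then remaining keys in insertion order); the helper name cap_metadata and the
# example values are illustrative only.
def cap_metadata(metadata: dict, priority: list, max_keys: int = 10) -> dict:
    capped = {}
    # Keep prioritized keys first
    for key in priority:
        if len(capped) >= max_keys:
            break
        if key in metadata:
            capped[key] = metadata[key]
    # Fill any remaining slots in insertion order
    for key, value in metadata.items():
        if len(capped) >= max_keys:
            break
        capped.setdefault(key, value)
    return capped

# cap_metadata({"text": "chunk", **{f"k{i}": i for i in range(20)}}, ["text", "file_id"])
# returns at most 10 keys, with "text" guaranteed to survive.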
def has_collection(self, collection_name: str) -> bool:
"""
Check if a vector index exists using direct lookup.
This avoids pagination issues with list_indexes() and is significantly faster.
"""
try:
self.client.get_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
return True
except Exception as e:
log.error(f"Error checking if index '{collection_name}' exists: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
"""
Delete an entire S3 Vector index/collection.
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
return
try:
log.info(f"Deleting collection '{collection_name}'")
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
log.info(f"Successfully deleted collection '{collection_name}'")
except Exception as e:
log.error(f"Error deleting collection '{collection_name}': {e}")
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert vector items into the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to insert")
return
dimension = len(items[0]["vector"])
try:
if not self.has_collection(collection_name):
log.info(f"Index '{collection_name}' does not exist. Creating index.")
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
)
# Prepare vectors for insertion
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
# Convert metadata to string format for consistency
metadata = process_metadata(metadata)
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
}
)
# Insert vectors in batches of 500 (S3 Vector API limit)
batch_size = 500
for i in range(0, len(vectors), batch_size):
batch = vectors[i : i + batch_size]
self.client.put_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
vectors=batch,
)
log.info(
f"Inserted batch {i//batch_size + 1}: {len(batch)} vectors into index '{collection_name}'."
)
log.info(
f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'."
)
except Exception as e:
log.error(f"Error inserting vectors: {e}")
raise
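# A minimal sketch of the 500-item batching above, assuming a boto3 "s3vectors"
# client and an existing index; the region, bucket and index names are
# placeholders.
import boto3

s3v = boto3.client("s3vectors", region_name="us-east-1")
sample_vectors = [
    {"key": f"id-{i}", "data": {"float32": [0.1, 0.2, 0.3]}, "metadata": {"text": "chunk"}}
    for i in range(1200)
]
batch_size = 500  # per-request limit handled by the loop above
for start in range(0, len(sample_vectors), batch_size):
    s3v.put_vectors(
        vectorBucketName="my-vector-bucket",
        indexName="open-webui-docs",
        vectors=sample_vectors[start : start + batch_size],
    )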
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to upsert")
return
dimension = len(items[0]["vector"])
log.info(f"Upsert dimension: {dimension}")
try:
if not self.has_collection(collection_name):
log.info(
f"Index '{collection_name}' does not exist. Creating index for upsert."
)
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
)
# Prepare vectors for upsert
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
# Convert metadata to string format for consistency
metadata = process_metadata(metadata)
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
}
)
# Upsert vectors in batches of 500 (S3 Vector API limit)
batch_size = 500
for i in range(0, len(vectors), batch_size):
batch = vectors[i : i + batch_size]
if i == 0: # Log sample info for first batch only
log.info(
f"Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]['key']}, data_type={type(batch[0]['data']['float32'])}, data_len={len(batch[0]['data']['float32'])}"
)
else:
log.info(
f"Upserting batch {i//batch_size + 1}: {len(batch)} vectors."
)
self.client.put_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
vectors=batch,
)
log.info(
f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'."
)
except Exception as e:
log.error(f"Error upserting vectors: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""
Search for similar vectors in a collection using multiple query vectors.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return None
if not vectors:
log.warning("No query vectors provided")
return None
try:
log.info(
f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
)
# Initialize result lists
all_ids = []
all_documents = []
all_metadatas = []
all_distances = []
# Process each query vector
for i, query_vector in enumerate(vectors):
log.debug(f"Processing query vector {i+1}/{len(vectors)}")
# Prepare the query vector in S3 Vector format
query_vector_dict = {"float32": [float(x) for x in query_vector]}
# Call S3 Vector query API
response = self.client.query_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
topK=limit,
queryVector=query_vector_dict,
returnMetadata=True,
returnDistance=True,
)
# Process results for this query
query_ids = []
query_documents = []
query_metadatas = []
query_distances = []
result_vectors = response.get("vectors", [])
for vector in result_vectors:
vector_id = vector.get("key")
vector_metadata = vector.get("metadata", {})
vector_distance = vector.get("distance", 0.0)
# Extract document text from metadata
document_text = ""
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
)
else:
document_text = vector_id
query_ids.append(vector_id)
query_documents.append(document_text)
query_metadatas.append(vector_metadata)
query_distances.append(vector_distance)
# Add this query's results to the overall results
all_ids.append(query_ids)
all_documents.append(query_documents)
all_metadatas.append(query_metadatas)
all_distances.append(query_distances)
log.info(f"Search completed. Found results for {len(all_ids)} queries")
# Return SearchResult format
return SearchResult(
ids=all_ids if all_ids else None,
documents=all_documents if all_documents else None,
metadatas=all_metadatas if all_metadatas else None,
distances=all_distances if all_distances else None,
)
except Exception as e:
log.error(f"Error searching collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return None
elif error_code == "ValidationException":
log.error(f"Invalid query vector dimensions or parameters")
return None
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return None
raise
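# A sketch of a single nearest-neighbour lookup with query_vectors, mirroring
# the per-query loop above; the region, bucket/index names and the query vector
# are placeholders.
import boto3

s3v = boto3.client("s3vectors", region_name="us-east-1")
response = s3v.query_vectors(
    vectorBucketName="my-vector-bucket",
    indexName="open-webui-docs",
    topK=5,
    queryVector={"float32": [0.1, 0.2, 0.3]},
    returnMetadata=True,
    returnDistance=True,
)
for hit in response.get("vectors", []):
    print(hit.get("key"), hit.get("distance"), hit.get("metadata", {}).get("text"))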
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""
Query vectors from a collection using metadata filter.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
if not filter:
log.warning("No filter provided, returning all vectors")
return self.get(collection_name)
try:
log.info(f"Querying collection '{collection_name}' with filter: {filter}")
# For S3 Vector, we need to use list_vectors and then filter results
# Since S3 Vector may not support complex server-side filtering,
# we'll retrieve all vectors and filter client-side
# Get all vectors first
all_vectors_result = self.get(collection_name)
if not all_vectors_result or not all_vectors_result.ids:
log.warning("No vectors found in collection")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
# Extract the lists from the result
all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
all_documents = (
all_vectors_result.documents[0] if all_vectors_result.documents else []
)
all_metadatas = (
all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
)
# Apply client-side filtering
filtered_ids = []
filtered_documents = []
filtered_metadatas = []
for i, metadata in enumerate(all_metadatas):
if self._matches_filter(metadata, filter):
if i < len(all_ids):
filtered_ids.append(all_ids[i])
if i < len(all_documents):
filtered_documents.append(all_documents[i])
filtered_metadatas.append(metadata)
# Apply limit if specified
if limit and len(filtered_ids) >= limit:
break
log.info(
f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
)
# Return GetResult format
if filtered_ids:
return GetResult(
ids=[filtered_ids],
documents=[filtered_documents],
metadatas=[filtered_metadatas],
)
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Retrieve all vectors from a collection.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
try:
log.info(f"Retrieving all vectors from collection '{collection_name}'")
# Initialize result lists
all_ids = []
all_documents = []
all_metadatas = []
# Handle pagination
next_token = None
while True:
# Prepare request parameters
request_params = {
"vectorBucketName": self.bucket_name,
"indexName": collection_name,
"returnData": False, # Don't include vector data (not needed for get)
"returnMetadata": True, # Include metadata
"maxResults": 500, # Use reasonable page size
}
if next_token:
request_params["nextToken"] = next_token
# Call S3 Vector API
response = self.client.list_vectors(**request_params)
# Process vectors in this page
vectors = response.get("vectors", [])
for vector in vectors:
vector_id = vector.get("key")
vector_data = vector.get("data", {})
vector_metadata = vector.get("metadata", {})
# Extract the actual vector array
vector_array = vector_data.get("float32", [])
# For documents, we try to extract text from metadata or use the vector ID
document_text = ""
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
)
# Log the actual content for debugging
log.debug(
f"Document text preview (first 200 chars): {str(document_text)[:200]}"
)
else:
document_text = vector_id
all_ids.append(vector_id)
all_documents.append(document_text)
all_metadatas.append(vector_metadata)
# Check if there are more pages
next_token = response.get("nextToken")
if not next_token:
break
log.info(
f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
)
# Return in GetResult format
# The Open WebUI GetResult expects lists of lists, so we wrap each list
if all_ids:
return GetResult(
ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
)
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(
f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
)
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""
Delete vectors by ID or filter from a collection.
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
return
# Check if this is a knowledge collection (not file-specific)
is_knowledge_collection = not collection_name.startswith("file-")
try:
if ids:
# Delete by specific vector IDs/keys
log.info(
f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
)
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
keys=ids,
)
log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
elif filter:
# Handle filter-based deletion
log.info(
f"Deleting vectors by filter from collection '{collection_name}': {filter}"
)
# If this is a knowledge collection and we have a file_id filter,
# also clean up the corresponding file-specific collection
if is_knowledge_collection and "file_id" in filter:
file_id = filter["file_id"]
file_collection_name = f"file-{file_id}"
if self.has_collection(file_collection_name):
log.info(
f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
)
self.delete_collection(file_collection_name)
# For the main collection, implement query-then-delete
# First, query to get IDs matching the filter
query_result = self.query(collection_name, filter)
if query_result and query_result.ids and query_result.ids[0]:
matching_ids = query_result.ids[0]
log.info(
f"Found {len(matching_ids)} vectors matching filter, deleting them"
)
# Delete the matching vectors by ID
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
keys=matching_ids,
)
log.info(
f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
)
else:
log.warning("No vectors found matching the filter criteria")
else:
log.warning("No IDs or filter provided for deletion")
except Exception as e:
log.error(
f"Error deleting vectors from collection '{collection_name}': {e}"
)
raise
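# A sketch of the query-then-delete pattern used above for filter-based
# deletion: resolve the matching keys first, then delete by key, since that is
# the only deletion primitive used here. Names and keys are placeholders.
import boto3

s3v = boto3.client("s3vectors", region_name="us-east-1")
matching_keys = ["vec-1", "vec-2"]  # stand-in for IDs returned by the filter query
if matching_keys:
    s3v.delete_vectors(
        vectorBucketName="my-vector-bucket",
        indexName="open-webui-docs",
        keys=matching_keys,
    )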
def reset(self) -> None:
"""
Reset/clear all vector data. For S3 Vector, this deletes all indexes.
"""
try:
log.warning(
"Reset called - this will delete all vector indexes in the S3 bucket"
)
# List all indexes
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
indexes = response.get("indexes", [])
if not indexes:
log.warning("No indexes found to delete")
return
# Delete all indexes
deleted_count = 0
for index in indexes:
index_name = index.get("indexName")
if index_name:
try:
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=index_name
)
deleted_count += 1
log.info(f"Deleted index: {index_name}")
except Exception as e:
log.error(f"Error deleting index '{index_name}': {e}")
log.info(f"Reset completed: deleted {deleted_count} indexes")
except Exception as e:
log.error(f"Error during reset: {e}")
raise
def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
"""
Check if metadata matches the given filter conditions.
"""
if not isinstance(metadata, dict) or not isinstance(filter, dict):
return False
# Check each filter condition
for key, expected_value in filter.items():
# Handle special operators
if key.startswith("$"):
if key == "$and":
# All conditions must match
if not isinstance(expected_value, list):
continue
for condition in expected_value:
if not self._matches_filter(metadata, condition):
return False
elif key == "$or":
# At least one condition must match
if not isinstance(expected_value, list):
continue
any_match = False
for condition in expected_value:
if self._matches_filter(metadata, condition):
any_match = True
break
if not any_match:
return False
continue
# Get the actual value from metadata
actual_value = metadata.get(key)
# Handle different types of expected values
if isinstance(expected_value, dict):
# Handle comparison operators
for op, op_value in expected_value.items():
if op == "$eq":
if actual_value != op_value:
return False
elif op == "$ne":
if actual_value == op_value:
return False
elif op == "$in":
if (
not isinstance(op_value, list)
or actual_value not in op_value
):
return False
elif op == "$nin":
if isinstance(op_value, list) and actual_value in op_value:
return False
elif op == "$exists":
if bool(op_value) != (key in metadata):
return False
# Add more operators as needed
else:
# Simple equality check
if actual_value != expected_value:
return False
return True
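# Worked examples of the filter semantics implemented above, against the
# illustrative metadata {"file_id": "f1", "page": 3}:
#   {"file_id": "f1"}                                   -> matches (simple equality)
#   {"page": {"$in": [1, 2, 3]}}                        -> matches
#   {"$or": [{"file_id": "f2"}, {"page": {"$eq": 3}}]}  -> matches
#   {"file_id": {"$ne": "f1"}}                          -> does not match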

View file

@ -1,340 +0,0 @@
import weaviate
import re
import uuid
from typing import Any, Dict, List, Optional, Union
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.config import (
WEAVIATE_HTTP_HOST,
WEAVIATE_HTTP_PORT,
WEAVIATE_GRPC_PORT,
WEAVIATE_API_KEY,
)
def _convert_uuids_to_strings(obj: Any) -> Any:
"""
Recursively convert UUID objects to strings in nested data structures.
This function handles:
- UUID objects -> string
- Dictionaries with UUID values
- Lists/Tuples with UUID values
- Nested combinations of the above
Args:
obj: Any object that might contain UUIDs
Returns:
The same object structure with UUIDs converted to strings
"""
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, dict):
return {key: _convert_uuids_to_strings(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_convert_uuids_to_strings(item) for item in obj)
elif isinstance(obj, (str, int, float, bool, type(None))):
return obj
else:
return obj
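# Illustrative use of the UUID normalisation above; the values are made up.
example = {
    "file_id": uuid.UUID("12345678-1234-5678-1234-567812345678"),
    "pages": [uuid.uuid4(), 3],
    "name": "doc",
}
cleaned = _convert_uuids_to_strings(example)
# cleaned["file_id"] == "12345678-1234-5678-1234-567812345678" and
# cleaned["pages"][0] is a plain string, so the result is JSON-serialisable.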
class WeaviateClient(VectorDBBase):
def __init__(self):
self.url = WEAVIATE_HTTP_HOST
try:
# Build connection parameters
connection_params = {
"host": WEAVIATE_HTTP_HOST,
"port": WEAVIATE_HTTP_PORT,
"grpc_port": WEAVIATE_GRPC_PORT,
}
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
if WEAVIATE_API_KEY:
connection_params["auth_credentials"] = (
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
)
self.client = weaviate.connect_to_local(**connection_params)
self.client.connect()
except Exception as e:
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
def _sanitize_collection_name(self, collection_name: str) -> str:
"""Sanitize collection name to be a valid Weaviate class name."""
if not isinstance(collection_name, str) or not collection_name.strip():
raise ValueError("Collection name must be a non-empty string")
# Requirements for a valid Weaviate class name:
# The collection name must begin with a capital letter.
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
# Replace hyphens with underscores and keep only alphanumeric characters
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
name = name.strip("_")
if not name:
raise ValueError(
"Could not sanitize collection name to be a valid Weaviate class name"
)
# Ensure it starts with a letter and is capitalized
if not name[0].isalpha():
name = "C" + name
return name[0].upper() + name[1:]
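# Illustrative inputs and outputs for the sanitisation above:
#   "open-webui_file-123" -> "Open_webui_file_123"
#   "9-docs"              -> "C9_docs"   (prefixed so the class name starts with a letter)
#   "user docs!"          -> "Userdocs"  (spaces and punctuation are stripped)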
def has_collection(self, collection_name: str) -> bool:
sane_collection_name = self._sanitize_collection_name(collection_name)
return self.client.collections.exists(sane_collection_name)
def delete_collection(self, collection_name: str) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if self.client.collections.exists(sane_collection_name):
self.client.collections.delete(sane_collection_name)
def _create_collection(self, collection_name: str) -> None:
self.client.collections.create(
name=collection_name,
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
properties=[
weaviate.classes.config.Property(
name="text", data_type=weaviate.classes.config.DataType.TEXT
),
],
)
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
self._create_collection(sane_collection_name)
collection = self.client.collections.get(sane_collection_name)
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
self._create_collection(sane_collection_name)
collection = self.client.collections.get(sane_collection_name)
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(item["id"]) if item["id"] else None
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
result_ids, result_documents, result_metadatas, result_distances = (
[],
[],
[],
[],
)
for vector_embedding in vectors:
try:
response = collection.query.near_vector(
near_vector=vector_embedding,
limit=limit,
return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
metadatas = []
distances = []
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
                # Weaviate returns cosine distance in [0, 2] (0 = best, 2 = worst);
                # convert it to a similarity score in [0, 1]. Check for None explicitly
                # so a perfect match (distance 0.0) is not treated as missing.
                raw_distances = [
                    (
                        obj.metadata.distance
                        if obj.metadata and obj.metadata.distance is not None
                        else 2.0
                    )
                    for obj in response.objects
                ]
                distances = [(2 - dist) / 2 for dist in raw_distances]
result_ids.append(ids)
result_documents.append(documents)
result_metadatas.append(metadatas)
result_distances.append(distances)
except Exception:
result_ids.append([])
result_documents.append([])
result_metadatas.append([])
result_distances.append([])
return SearchResult(
**{
"ids": result_ids,
"documents": result_documents,
"metadatas": result_metadatas,
"distances": result_distances,
}
)
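# The distance-to-similarity mapping above, worked through on sample values.
raw = [0.0, 1.0, 2.0]                # identical, orthogonal, opposite vectors
scores = [(2 - d) / 2 for d in raw]  # -> [1.0, 0.5, 0.0]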
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
weaviate_filter = None
if filter:
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
value
)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
)
try:
response = collection.query.fetch_objects(
filters=weaviate_filter, limit=limit
)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
metadatas = []
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
except Exception:
return None
def get(self, collection_name: str) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
collection = self.client.collections.get(sane_collection_name)
ids, documents, metadatas = [], [], []
try:
for item in collection.iterator():
ids.append(str(item.uuid))
properties = dict(item.properties) if item.properties else {}
documents.append(properties.pop("text", ""))
metadatas.append(_convert_uuids_to_strings(properties))
if not ids:
return None
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
except Exception:
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return
collection = self.client.collections.get(sane_collection_name)
try:
if ids:
for item_id in ids:
collection.data.delete_by_id(uuid=item_id)
elif filter:
weaviate_filter = None
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(
name=key
).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
)
if weaviate_filter:
collection.data.delete_many(where=weaviate_filter)
except Exception:
pass
def reset(self) -> None:
try:
for collection_name in self.client.collections.list_all().keys():
self.client.collections.delete(collection_name)
except Exception:
pass

View file

@ -1,10 +1,6 @@
from open_webui.retrieval.vector.main import VectorDBBase
from open_webui.retrieval.vector.type import VectorType
from open_webui.config import (
VECTOR_DB,
ENABLE_QDRANT_MULTITENANCY_MODE,
ENABLE_MILVUS_MULTITENANCY_MODE,
)
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE
class Vector:
@ -16,16 +12,9 @@ class Vector:
"""
match vector_type:
case VectorType.MILVUS:
if ENABLE_MILVUS_MULTITENANCY_MODE:
from open_webui.retrieval.vector.dbs.milvus_multitenancy import (
MilvusClient,
)
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
return MilvusClient()
else:
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
return MilvusClient()
return MilvusClient()
case VectorType.QDRANT:
if ENABLE_QDRANT_MULTITENANCY_MODE:
from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
@ -41,10 +30,6 @@ class Vector:
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
return PineconeClient()
case VectorType.S3VECTOR:
from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient
return S3VectorClient()
case VectorType.OPENSEARCH:
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
@ -63,14 +48,6 @@ class Vector:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
return ChromaClient()
case VectorType.ORACLE23AI:
from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient
return Oracle23aiClient()
case VectorType.WEAVIATE:
from open_webui.retrieval.vector.dbs.weaviate import WeaviateClient
return WeaviateClient()
case _:
raise ValueError(f"Unsupported vector type: {vector_type}")

View file

@ -9,6 +9,3 @@ class VectorType(StrEnum):
ELASTICSEARCH = "elasticsearch"
OPENSEARCH = "opensearch"
PGVECTOR = "pgvector"
ORACLE23AI = "oracle23ai"
S3VECTOR = "s3vector"
WEAVIATE = "weaviate"

View file

@ -1,28 +0,0 @@
from datetime import datetime
KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"]
def filter_metadata(metadata: dict[str, any]) -> dict[str, any]:
metadata = {
key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE
}
return metadata
def process_metadata(
metadata: dict[str, any],
) -> dict[str, any]:
    # Iterate over a snapshot of the items so keys can be removed safely
    for key, value in list(metadata.items()):
        # Remove large fields
        if key in KEYS_TO_EXCLUDE:
            del metadata[key]
            continue
        # Convert non-serializable fields to strings
        if isinstance(value, (datetime, list, dict)):
            metadata[key] = str(value)
return metadata
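# Illustrative call (values made up): excluded keys such as "content" are
# dropped, and datetimes, lists and dicts are stringified so the metadata stays
# serialisable.
example = {
    "source": "a.pdf",
    "tags": ["rag", "pdf"],
    "created": datetime(2024, 1, 1),
    "content": "...full page text...",
}
print(process_metadata(example))
# -> {'source': 'a.pdf', 'tags': "['rag', 'pdf']", 'created': '2024-01-01 00:00:00'}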

View file

@ -1,126 +0,0 @@
import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
log = logging.getLogger(__name__)
"""
Azure AI Search integration for Open WebUI.
Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python
Required package: azure-search-documents
Install: pip install azure-search-documents
"""
def search_azure(
api_key: str,
endpoint: str,
index_name: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""
Search using Azure AI Search.
Args:
api_key: Azure Search API key (query key or admin key)
endpoint: Azure Search service endpoint (e.g., https://myservice.search.windows.net)
index_name: Name of the search index to query
query: Search query string
count: Number of results to return
filter_list: Optional list of domains to filter results
Returns:
List of SearchResult objects with link, title, and snippet
"""
try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError:
log.error(
"azure-search-documents package is not installed. "
"Install it with: pip install azure-search-documents"
)
raise ImportError(
"azure-search-documents is required for Azure AI Search. "
"Install it with: pip install azure-search-documents"
)
try:
# Create search client with API key authentication
credential = AzureKeyCredential(api_key)
search_client = SearchClient(
endpoint=endpoint, index_name=index_name, credential=credential
)
# Perform the search
results = search_client.search(search_text=query, top=count)
# Convert results to list and extract fields
search_results = []
for result in results:
# Azure AI Search returns documents with custom schemas
# We need to extract common fields that might represent URL, title, and content
# Common field names to look for:
result_dict = dict(result)
# Try to find URL field (common names)
link = (
result_dict.get("url")
or result_dict.get("link")
or result_dict.get("uri")
or result_dict.get("metadata_storage_path")
or ""
)
# Try to find title field (common names)
title = (
result_dict.get("title")
or result_dict.get("name")
or result_dict.get("metadata_title")
or result_dict.get("metadata_storage_name")
or None
)
# Try to find content/snippet field (common names)
snippet = (
result_dict.get("content")
or result_dict.get("snippet")
or result_dict.get("description")
or result_dict.get("summary")
or result_dict.get("text")
or None
)
# Truncate snippet if too long
if snippet and len(snippet) > 500:
snippet = snippet[:497] + "..."
if link: # Only add if we found a valid link
search_results.append(
{
"link": link,
"title": title,
"snippet": snippet,
}
)
# Apply domain filtering if specified
if filter_list:
search_results = get_filtered_results(search_results, filter_list)
# Convert to SearchResult objects
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in search_results
]
except Exception as ex:
log.error(f"Azure AI Search error: {ex}")
raise ex
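# A hypothetical call to the helper above, assuming SearchResult exposes
# link/title as attributes; the endpoint, key and index name are placeholders.
results = search_azure(
    api_key="<azure-search-query-key>",
    endpoint="https://myservice.search.windows.net",
    index_name="docs-index",
    query="open webui retrieval",
    count=5,
    filter_list=["learn.microsoft.com"],
)
for r in results:
    print(r.link, r.title)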

View file

@ -4,9 +4,11 @@ from pprint import pprint
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
import argparse
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
"""
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
"""

View file

@ -4,18 +4,20 @@ from typing import Optional
import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def _parse_response(response):
results = []
result = {}
if "data" in response:
data = response["data"]
if "webPages" in data:
webPages = data["webPages"]
if "value" in webPages:
results = [
result["webpage"] = [
{
"id": item.get("id", ""),
"name": item.get("name", ""),
@ -29,7 +31,7 @@ def _parse_response(response):
}
for item in webPages["value"]
]
return results
return result
def search_bocha(
@ -51,7 +53,7 @@ def search_bocha(
response = requests.post(url, headers=headers, data=payload, timeout=5)
response.raise_for_status()
results = _parse_response(response.json())
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
@ -59,5 +61,5 @@ def search_bocha(
SearchResult(
link=result["url"], title=result.get("name"), snippet=result.get("summary")
)
for result in results[:count]
for result in results.get("webpage", [])[:count]
]

View file

@ -3,8 +3,10 @@ from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(
@ -34,9 +36,7 @@ def search_brave(
return [
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("description"),
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:count]
]

View file

@ -2,17 +2,16 @@ import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from ddgs import DDGS
from ddgs.exceptions import RatelimitException
from duckduckgo_search import DDGS
from duckduckgo_search.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
concurrent_requests: Optional[int] = None,
query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
@ -26,9 +25,6 @@ def search_duckduckgo(
# Use the DDGS context manager to create a DDGS object
search_results = []
with DDGS() as ddgs:
if concurrent_requests:
ddgs.threads = concurrent_requests
# Use the ddgs.text() method to perform the search
try:
search_results = ddgs.text(

View file

@ -3,9 +3,11 @@ from dataclasses import dataclass
from typing import Optional
import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
EXA_API_BASE = "https://api.exa.ai"

View file

@ -2,40 +2,27 @@ import logging
from typing import Optional, List
import requests
from fastapi import Request
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.utils.headers import include_user_info_headers
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_external(
request: Request,
external_url: str,
external_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
user=None,
) -> List[SearchResult]:
try:
headers = {
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
}
headers = include_user_info_headers(headers, user)
chat_id = getattr(request.state, "chat_id", None)
if chat_id:
headers["X-OpenWebUI-Chat-Id"] = str(chat_id)
response = requests.post(
external_url,
headers=headers,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
},
json={
"query": query,
"count": count,

View file

@ -1,10 +1,13 @@
import logging
from typing import Optional, List
from urllib.parse import urljoin
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_firecrawl(
@ -15,20 +18,27 @@ def search_firecrawl(
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
from firecrawl import FirecrawlApp
firecrawl = FirecrawlApp(api_key=firecrawl_api_key, api_url=firecrawl_url)
response = firecrawl.search(
query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3
firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
response = requests.post(
firecrawl_search_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {firecrawl_api_key}",
},
json={
"query": query,
"limit": count,
},
)
results = response.web
response.raise_for_status()
results = response.json().get("data", [])
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.url,
title=result.title,
snippet=result.description,
link=result.get("url"),
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]
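# A minimal sketch of the REST variant shown above; the base URL and API key
# are placeholders.
from urllib.parse import urljoin
import requests

resp = requests.post(
    urljoin("https://api.firecrawl.dev", "/v1/search"),
    headers={"Authorization": "Bearer <FIRECRAWL_API_KEY>"},
    json={"query": "open webui", "limit": 5},
)
resp.raise_for_status()
for item in resp.json().get("data", [])[:5]:
    print(item.get("url"), item.get("title"))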

View file

@ -3,8 +3,10 @@ from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
@ -13,7 +15,6 @@ def search_google_pse(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
referer: Optional[str] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
Handles pagination for counts greater than 10.
@ -29,11 +30,7 @@ def search_google_pse(
list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
headers = {"Content-Type": "application/json"}
if referer:
headers["Referer"] = referer
all_results = []
start_index = 1 # Google PSE start parameter is 1-based
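# A hedged sketch of the 10-per-page pagination the docstring above describes,
# using the public Custom Search JSON API parameter names (key, cx, q, num,
# start); the function name pse_page and the key/cx values are hypothetical.
import requests

def pse_page(query: str, count: int, api_key: str, cx: str) -> list:
    url = "https://www.googleapis.com/customsearch/v1"
    items, start = [], 1
    while len(items) < count:
        num = min(10, count - len(items))  # Google PSE returns at most 10 per call
        resp = requests.get(
            url,
            params={"key": api_key, "cx": cx, "q": query, "num": num, "start": start},
            timeout=10,
        )
        resp.raise_for_status()
        page = resp.json().get("items", [])
        if not page:
            break
        items.extend(page)
        start += len(page)
    return items[:count]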

View file

@ -2,9 +2,11 @@ import logging
import requests
from open_webui.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
from yarl import URL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:

View file

@ -3,8 +3,10 @@ from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_kagi(

Some files were not shown because too many files have changed in this diff.