diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py
index 87df03238f..8a9d410e6d 100644
--- a/backend/open_webui/apps/retrieval/main.py
+++ b/backend/open_webui/apps/retrieval/main.py
@@ -709,8 +709,8 @@ def save_docs_to_vector_db(
if overwrite:
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
log.info(f"deleting existing collection {collection_name}")
-
- if add is False:
+ elif add is False:
+ log.info(f"collection {collection_name} already exists, overwrite is False and add is False")
return True
log.info(f"adding to collection {collection_name}")
diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py
index aa09ec5827..153bd804ff 100644
--- a/backend/open_webui/apps/retrieval/utils.py
+++ b/backend/open_webui/apps/retrieval/utils.py
@@ -385,6 +385,8 @@ def get_rag_context(
extracted_collections.extend(collection_names)
if context:
+ if "data" in file:
+ del file["data"]
relevant_contexts.append({**context, "file": file})
contexts = []
@@ -401,11 +403,8 @@ def get_rag_context(
]
)
)
-
contexts.append(
- (", ".join(file_names) + ":\n\n")
- if file_names
- else ""
+ ((", ".join(file_names) + ":\n\n") if file_names else "")
+ "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
@@ -423,7 +422,9 @@ def get_rag_context(
except Exception as e:
log.exception(e)
- print(contexts, citations)
+ print("contexts", contexts)
+ print("citations", citations)
+
return contexts, citations
diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py
index 1d12d708eb..94e42f4a80 100644
--- a/backend/open_webui/apps/webui/main.py
+++ b/backend/open_webui/apps/webui/main.py
@@ -9,6 +9,7 @@ from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import (
auths,
chats,
+ folders,
configs,
files,
functions,
@@ -110,6 +111,7 @@ app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
+app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
diff --git a/backend/open_webui/apps/webui/models/chats.py b/backend/open_webui/apps/webui/models/chats.py
index 509dff9feb..84f8a74d1e 100644
--- a/backend/open_webui/apps/webui/models/chats.py
+++ b/backend/open_webui/apps/webui/models/chats.py
@@ -33,6 +33,7 @@ class Chat(Base):
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
+ folder_id = Column(Text, nullable=True)
class ChatModel(BaseModel):
@@ -51,6 +52,7 @@ class ChatModel(BaseModel):
pinned: Optional[bool] = False
meta: dict = {}
+ folder_id: Optional[str] = None
####################
@@ -61,10 +63,12 @@ class ChatModel(BaseModel):
class ChatForm(BaseModel):
chat: dict
+
class ChatTitleMessagesForm(BaseModel):
title: str
messages: list[dict]
+
class ChatTitleForm(BaseModel):
title: str
@@ -80,6 +84,7 @@ class ChatResponse(BaseModel):
archived: bool
pinned: Optional[bool] = False
meta: dict = {}
+ folder_id: Optional[str] = None
class ChatTitleIdResponse(BaseModel):
@@ -252,14 +257,18 @@ class ChatTable:
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
- query = db.query(Chat).filter_by(user_id=user_id)
+ query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
if not include_archived:
query = query.filter_by(archived=False)
- all_chats = (
- query.order_by(Chat.updated_at.desc())
- # .limit(limit).offset(skip)
- .all()
- )
+
+ query = query.order_by(Chat.updated_at.desc())
+
+ if skip:
+ query = query.offset(skip)
+ if limit:
+ query = query.limit(limit)
+
+ all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id(
@@ -270,7 +279,9 @@ class ChatTable:
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
- query = db.query(Chat).filter_by(user_id=user_id)
+ query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
+ query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
+
if not include_archived:
query = query.filter_by(archived=False)
@@ -361,7 +372,7 @@ class ChatTable:
with get_db() as db:
all_chats = (
db.query(Chat)
- .filter_by(user_id=user_id, pinned=True)
+ .filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
@@ -387,9 +398,25 @@ class ChatTable:
Filters chats based on a search query using Python, allowing pagination using skip and limit.
"""
search_text = search_text.lower().strip()
+
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
+ search_text_words = search_text.split(" ")
+
+ # search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
+ tag_ids = [
+ word.replace("tag:", "").replace(" ", "_").lower()
+ for word in search_text_words
+ if word.startswith("tag:")
+ ]
+
+ search_text_words = [
+ word for word in search_text_words if not word.startswith("tag:")
+ ]
+
+ search_text = " ".join(search_text_words)
+
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
@@ -418,6 +445,26 @@ class ChatTable:
)
).params(search_text=search_text)
)
+
+ # Check if there are any tags to filter, it should have all the tags
+ if tag_ids:
+ query = query.filter(
+ and_(
+ *[
+ text(
+ f"""
+ EXISTS (
+ SELECT 1
+ FROM json_each(Chat.meta, '$.tags') AS tag
+ WHERE tag.value = :tag_id_{tag_idx}
+ )
+ """
+ ).params(**{f"tag_id_{tag_idx}": tag_id})
+ for tag_idx, tag_id in enumerate(tag_ids)
+ ]
+ )
+ )
+
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
query = query.filter(
@@ -436,6 +483,25 @@ class ChatTable:
)
).params(search_text=search_text)
)
+
+ # Check if there are any tags to filter, it should have all the tags
+ if tag_ids:
+ query = query.filter(
+ and_(
+ *[
+ text(
+ f"""
+ EXISTS (
+ SELECT 1
+ FROM json_array_elements_text(Chat.meta->'tags') AS tag
+ WHERE tag = :tag_id_{tag_idx}
+ )
+ """
+ ).params(**{f"tag_id_{tag_idx}": tag_id})
+ for tag_idx, tag_id in enumerate(tag_ids)
+ ]
+ )
+ )
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
@@ -444,9 +510,34 @@ class ChatTable:
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
+ print(len(all_chats))
+
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
+ def get_chats_by_folder_id_and_user_id(
+ self, folder_id: str, user_id: str
+ ) -> list[ChatModel]:
+ with get_db() as db:
+ all_chats = (
+ db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id).all()
+ )
+ return [ChatModel.model_validate(chat) for chat in all_chats]
+
+ def update_chat_folder_id_by_id_and_user_id(
+ self, id: str, user_id: str, folder_id: str
+ ) -> Optional[ChatModel]:
+ try:
+ with get_db() as db:
+ chat = db.get(Chat, id)
+ chat.folder_id = folder_id
+ chat.updated_at = int(time.time())
+ db.commit()
+ db.refresh(chat)
+ return ChatModel.model_validate(chat)
+ except Exception:
+ return None
+
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
@@ -498,7 +589,7 @@ class ChatTable:
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
- "tags": chat.meta.get("tags", []) + [tag_id],
+ "tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}
db.commit()
@@ -509,7 +600,7 @@ class ChatTable:
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
- query = db.query(Chat).filter_by(user_id=user_id)
+ query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
@@ -555,7 +646,7 @@ class ChatTable:
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
- "tags": tags,
+ "tags": list(set(tags)),
}
db.commit()
return True
diff --git a/backend/open_webui/apps/webui/models/folders.py b/backend/open_webui/apps/webui/models/folders.py
new file mode 100644
index 0000000000..91aa0175e3
--- /dev/null
+++ b/backend/open_webui/apps/webui/models/folders.py
@@ -0,0 +1,225 @@
+import logging
+import time
+import uuid
+from typing import Optional
+
+from open_webui.apps.webui.internal.db import Base, get_db
+
+
+from open_webui.env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+
+####################
+# Folder DB Schema
+####################
+
+
+class Folder(Base):
+ __tablename__ = "folder"
+ id = Column(Text, primary_key=True)
+ parent_id = Column(Text, nullable=True)
+ user_id = Column(Text)
+ name = Column(Text)
+ items = Column(JSON, nullable=True)
+ meta = Column(JSON, nullable=True)
+ is_expanded = Column(Boolean, default=False)
+ created_at = Column(BigInteger)
+ updated_at = Column(BigInteger)
+
+
+class FolderModel(BaseModel):
+ id: str
+ parent_id: Optional[str] = None
+ user_id: str
+ name: str
+ items: Optional[dict] = None
+ meta: Optional[dict] = None
+ is_expanded: bool = False
+ created_at: int
+ updated_at: int
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+####################
+# Forms
+####################
+
+
+class FolderForm(BaseModel):
+ name: str
+ model_config = ConfigDict(extra="allow")
+
+
+class FolderTable:
+ def insert_new_folder(
+ self, user_id: str, name: str, parent_id: Optional[str] = None
+ ) -> Optional[FolderModel]:
+ with get_db() as db:
+ id = str(uuid.uuid4())
+ folder = FolderModel(
+ **{
+ "id": id,
+ "user_id": user_id,
+ "name": name,
+ "parent_id": parent_id,
+ "created_at": int(time.time()),
+ "updated_at": int(time.time()),
+ }
+ )
+ try:
+ result = Folder(**folder.model_dump())
+ db.add(result)
+ db.commit()
+ db.refresh(result)
+ if result:
+ return FolderModel.model_validate(result)
+ else:
+ return None
+ except Exception as e:
+ print(e)
+ return None
+
+ def get_folder_by_id_and_user_id(
+ self, id: str, user_id: str
+ ) -> Optional[FolderModel]:
+ try:
+ with get_db() as db:
+ folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
+
+ if not folder:
+ return None
+
+ return FolderModel.model_validate(folder)
+ except Exception:
+ return None
+
+ def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
+ with get_db() as db:
+ return [
+ FolderModel.model_validate(folder)
+ for folder in db.query(Folder).filter_by(user_id=user_id).all()
+ ]
+
+ def get_folder_by_parent_id_and_user_id_and_name(
+ self, parent_id: Optional[str], user_id: str, name: str
+ ) -> Optional[FolderModel]:
+ try:
+ with get_db() as db:
+ # Check if folder exists
+ folder = (
+ db.query(Folder)
+ .filter_by(parent_id=parent_id, user_id=user_id)
+ .filter(Folder.name.ilike(name))
+ .first()
+ )
+
+ if not folder:
+ return None
+
+ return FolderModel.model_validate(folder)
+ except Exception as e:
+ log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
+ return None
+
+ def get_folders_by_parent_id_and_user_id(
+ self, parent_id: Optional[str], user_id: str
+ ) -> list[FolderModel]:
+ with get_db() as db:
+ return [
+ FolderModel.model_validate(folder)
+ for folder in db.query(Folder)
+ .filter_by(parent_id=parent_id, user_id=user_id)
+ .all()
+ ]
+
+ def update_folder_parent_id_by_id_and_user_id(
+ self,
+ id: str,
+ user_id: str,
+ parent_id: str,
+ ) -> Optional[FolderModel]:
+ try:
+ with get_db() as db:
+ folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
+
+ if not folder:
+ return None
+
+ folder.parent_id = parent_id
+ folder.updated_at = int(time.time())
+
+ db.commit()
+
+ return FolderModel.model_validate(folder)
+ except Exception as e:
+ log.error(f"update_folder: {e}")
+ return
+
+ def update_folder_name_by_id_and_user_id(
+ self, id: str, user_id: str, name: str
+ ) -> Optional[FolderModel]:
+ try:
+ with get_db() as db:
+ folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
+
+ if not folder:
+ return None
+
+ existing_folder = (
+ db.query(Folder)
+ .filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
+ .first()
+ )
+
+ if existing_folder:
+ return None
+
+ folder.name = name
+ folder.updated_at = int(time.time())
+
+ db.commit()
+
+ return FolderModel.model_validate(folder)
+ except Exception as e:
+ log.error(f"update_folder: {e}")
+ return
+
+ def update_folder_is_expanded_by_id_and_user_id(
+ self, id: str, user_id: str, is_expanded: bool
+ ) -> Optional[FolderModel]:
+ try:
+ with get_db() as db:
+ folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
+
+ if not folder:
+ return None
+
+ folder.is_expanded = is_expanded
+ folder.updated_at = int(time.time())
+
+ db.commit()
+
+ return FolderModel.model_validate(folder)
+ except Exception as e:
+ log.error(f"update_folder: {e}")
+ return
+
+ def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+ try:
+ with get_db() as db:
+ folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
+ db.delete(folder)
+ db.commit()
+ return True
+ except Exception as e:
+ log.error(f"delete_folder: {e}")
+ return False
+
+
+Folders = FolderTable()
diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py
index e9f94ff6a0..35fd254f12 100644
--- a/backend/open_webui/apps/webui/routers/auths.py
+++ b/backend/open_webui/apps/webui/routers/auths.py
@@ -1,10 +1,13 @@
import re
import uuid
+import time
+import datetime
from open_webui.apps.webui.models.auths import (
AddUserForm,
ApiKey,
Auths,
+ Token,
SigninForm,
SigninResponse,
SignupForm,
@@ -34,6 +37,7 @@ from open_webui.utils.utils import (
get_password_hash,
)
from open_webui.utils.webhook import post_webhook
+from typing import Optional
router = APIRouter()
@@ -42,25 +46,44 @@ router = APIRouter()
############################
-@router.get("/", response_model=UserResponse)
+class SessionUserResponse(Token, UserResponse):
+ expires_at: Optional[int] = None
+
+
+@router.get("/", response_model=SessionUserResponse)
async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user)
):
+ expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
+ expires_at = None
+ if expires_delta:
+ expires_at = int(time.time()) + int(expires_delta.total_seconds())
+
token = create_token(
data={"id": user.id},
- expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+ expires_delta=expires_delta,
+ )
+
+ datetime_expires_at = (
+ datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
+ if expires_at
+ else None
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
+ expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
- samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
- secure=WEBUI_SESSION_COOKIE_SECURE,
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
+ "token": token,
+ "token_type": "Bearer",
+ "expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@@ -119,7 +142,7 @@ async def update_password(
############################
-@router.post("/signin", response_model=SigninResponse)
+@router.post("/signin", response_model=SessionUserResponse)
async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
@@ -161,23 +184,37 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user:
+
+ expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
+ expires_at = None
+ if expires_delta:
+ expires_at = int(time.time()) + int(expires_delta.total_seconds())
+
token = create_token(
data={"id": user.id},
- expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+ expires_delta=expires_delta,
+ )
+
+ datetime_expires_at = (
+ datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
+ if expires_at
+ else None
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
+ expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
- samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
- secure=WEBUI_SESSION_COOKIE_SECURE,
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
return {
"token": token,
"token_type": "Bearer",
+ "expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@@ -193,7 +230,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
############################
-@router.post("/signup", response_model=SigninResponse)
+@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
if WEBUI_AUTH:
if (
@@ -233,18 +270,30 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
)
if user:
+ expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
+ expires_at = None
+ if expires_delta:
+ expires_at = int(time.time()) + int(expires_delta.total_seconds())
+
token = create_token(
data={"id": user.id},
- expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+ expires_delta=expires_delta,
+ )
+
+ datetime_expires_at = (
+ datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
+ if expires_at
+ else None
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
+ expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
- samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
- secure=WEBUI_SESSION_COOKIE_SECURE,
+ samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
+ secure=WEBUI_SESSION_COOKIE_SECURE,
)
if request.app.state.config.WEBHOOK_URL:
@@ -261,6 +310,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
return {
"token": token,
"token_type": "Bearer",
+ "expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
diff --git a/backend/open_webui/apps/webui/routers/chats.py b/backend/open_webui/apps/webui/routers/chats.py
index b919d14473..9c404a9a7f 100644
--- a/backend/open_webui/apps/webui/routers/chats.py
+++ b/backend/open_webui/apps/webui/routers/chats.py
@@ -114,13 +114,24 @@ async def search_user_chats(
limit = 60
skip = (page - 1) * limit
- return [
+ chat_list = [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_chats_by_user_id_and_search_text(
user.id, text, skip=skip, limit=limit
)
]
+ # Delete tag if no chat is found
+ words = text.strip().split(" ")
+ if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
+ tag_id = words[0].replace("tag:", "")
+ if len(chat_list) == 0:
+ if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
+ log.debug(f"deleting tag: {tag_id}")
+ Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
+
+ return chat_list
+
############################
# GetPinnedChats
@@ -315,7 +326,13 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role == "admin":
+ chat = Chats.get_chat_by_id(id)
+ for tag in chat.meta.get("tags", []):
+ if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
+ Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
result = Chats.delete_chat_by_id(id)
+
return result
else:
if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
@@ -326,6 +343,11 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
+ chat = Chats.get_chat_by_id(id)
+ for tag in chat.meta.get("tags", []):
+ if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
+ Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
@@ -397,6 +419,20 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
+
+ # Delete tags if chat is archived
+ if chat.archived:
+ for tag_id in chat.meta.get("tags", []):
+ if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
+ log.debug(f"deleting tag: {tag_id}")
+ Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
+ else:
+ for tag_id in chat.meta.get("tags", []):
+ tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
+ if tag is None:
+ log.debug(f"inserting tag: {tag_id}")
+ tag = Tags.insert_new_tag(tag_id, user.id)
+
return ChatResponse(**chat.model_dump())
else:
raise HTTPException(
@@ -455,6 +491,31 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
)
+############################
+# UpdateChatFolderIdById
+############################
+
+
+class ChatFolderIdForm(BaseModel):
+ folder_id: Optional[str] = None
+
+
+@router.post("/{id}/folder", response_model=Optional[ChatResponse])
+async def update_chat_folder_id_by_id(
+ id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
+):
+ chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+ if chat:
+ chat = Chats.update_chat_folder_id_by_id_and_user_id(
+ id, user.id, form_data.folder_id
+ )
+ return ChatResponse(**chat.model_dump())
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
+ )
+
+
############################
# GetChatTagsById
############################
diff --git a/backend/open_webui/apps/webui/routers/folders.py b/backend/open_webui/apps/webui/routers/folders.py
new file mode 100644
index 0000000000..08f07b8c64
--- /dev/null
+++ b/backend/open_webui/apps/webui/routers/folders.py
@@ -0,0 +1,259 @@
+import logging
+import os
+import shutil
+import uuid
+from pathlib import Path
+from typing import Optional
+from pydantic import BaseModel
+import mimetypes
+
+
+from open_webui.apps.webui.models.folders import (
+ FolderForm,
+ FolderModel,
+ Folders,
+)
+from open_webui.apps.webui.models.chats import Chats
+
+from open_webui.config import UPLOAD_DIR
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.constants import ERROR_MESSAGES
+
+
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi.responses import FileResponse, StreamingResponse
+
+
+from open_webui.utils.utils import get_admin_user, get_verified_user
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+
+router = APIRouter()
+
+
+############################
+# Get Folders
+############################
+
+
+@router.get("/", response_model=list[FolderModel])
+async def get_folders(user=Depends(get_verified_user)):
+ folders = Folders.get_folders_by_user_id(user.id)
+
+ return [
+ {
+ **folder.model_dump(),
+ "items": {
+ "chats": [
+ {"title": chat.title, "id": chat.id}
+ for chat in Chats.get_chats_by_folder_id_and_user_id(
+ folder.id, user.id
+ )
+ ]
+ },
+ }
+ for folder in folders
+ ]
+
+
+############################
+# Create Folder
+############################
+
+
+@router.post("/")
+def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
+ folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
+ None, user.id, form_data.name
+ )
+
+ if folder:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
+ )
+
+ try:
+ folder = Folders.insert_new_folder(user.id, form_data.name)
+ return folder
+ except Exception as e:
+ log.exception(e)
+ log.error("Error creating folder")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error creating folder"),
+ )
+
+
+############################
+# Get Folders By Id
+############################
+
+
+@router.get("/{id}", response_model=Optional[FolderModel])
+async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
+ folder = Folders.get_folder_by_id_and_user_id(id, user.id)
+ if folder:
+ return folder
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# Update Folder Name By Id
+############################
+
+
+@router.post("/{id}/update")
+async def update_folder_name_by_id(
+ id: str, form_data: FolderForm, user=Depends(get_verified_user)
+):
+ folder = Folders.get_folder_by_id_and_user_id(id, user.id)
+ if folder:
+ existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
+ folder.parent_id, user.id, form_data.name
+ )
+ if existing_folder:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
+ )
+
+ try:
+ folder = Folders.update_folder_name_by_id_and_user_id(
+ id, user.id, form_data.name
+ )
+
+ return folder
+ except Exception as e:
+ log.exception(e)
+ log.error(f"Error updating folder: {id}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# Update Folder Parent Id By Id
+############################
+
+
+class FolderParentIdForm(BaseModel):
+ parent_id: Optional[str] = None
+
+
+@router.post("/{id}/update/parent")
+async def update_folder_parent_id_by_id(
+ id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
+):
+ folder = Folders.get_folder_by_id_and_user_id(id, user.id)
+ if folder:
+ existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
+ form_data.parent_id, user.id, folder.name
+ )
+
+ if existing_folder:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
+ )
+
+ try:
+ folder = Folders.update_folder_parent_id_by_id_and_user_id(
+ id, user.id, form_data.parent_id
+ )
+ return folder
+ except Exception as e:
+ log.exception(e)
+ log.error(f"Error updating folder: {id}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# Update Folder Is Expanded By Id
+############################
+
+
+class FolderIsExpandedForm(BaseModel):
+ is_expanded: bool
+
+
+@router.post("/{id}/update/expanded")
+async def update_folder_is_expanded_by_id(
+ id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
+):
+ folder = Folders.get_folder_by_id_and_user_id(id, user.id)
+ if folder:
+ try:
+ folder = Folders.update_folder_is_expanded_by_id_and_user_id(
+ id, user.id, form_data.is_expanded
+ )
+ return folder
+ except Exception as e:
+ log.exception(e)
+ log.error(f"Error updating folder: {id}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# Delete Folder By Id
+############################
+
+
+@router.delete("/{id}")
+async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
+ folder = Folders.get_folder_by_id_and_user_id(id, user.id)
+ if folder:
+ try:
+ result = Folders.delete_folder_by_id_and_user_id(id, user.id)
+ if result:
+ # Delete all chats in the folder
+ chats = Chats.get_chats_by_folder_id_and_user_id(id, user.id)
+ for chat in chats:
+ Chats.delete_chat_by_id(chat.id, user.id)
+
+ return result
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
+ )
+ except Exception as e:
+ log.exception(e)
+ log.error(f"Error deleting folder: {id}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index d55619ee0d..05b27c4471 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1042,7 +1042,7 @@ CHUNK_OVERLAP = PersistentConfig(
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.