mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: batch file processing
Co-Authored-By: Sihyeon Jang <24850223+sihyeonn@users.noreply.github.com>
This commit is contained in:
parent
284764e178
commit
a65cc196a5
2 changed files with 57 additions and 20 deletions
|
|
@ -98,6 +98,13 @@ class FileForm(BaseModel):
|
||||||
access_control: Optional[dict] = None
|
access_control: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileUpdateForm(BaseModel):
    """Payload describing an update to a single stored file.

    ``id`` is required; the remaining fields are optional and default to
    ``None`` when the caller does not supply them.
    """

    # Identifier of the file this update refers to.
    id: str

    # New hash value for the file, if any.
    hash: Optional[str] = None

    # Entries for the file's ``data`` dict, if any.
    data: Optional[dict] = None

    # Entries for the file's ``meta`` dict, if any.
    meta: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class FilesTable:
|
class FilesTable:
|
||||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
@ -204,6 +211,29 @@ class FilesTable:
|
||||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
for file in db.query(File).filter_by(user_id=user_id).all()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def update_file_by_id(
    self, id: str, form_data: FileUpdateForm
) -> Optional[FileModel]:
    """Apply a partial update to the file row identified by ``id``.

    Only the fields of ``form_data`` that are not ``None`` are applied:

    * ``hash`` replaces the stored hash.
    * ``data`` and ``meta`` are shallow-merged into the stored dicts, with
      keys from ``form_data`` taking precedence over existing keys.

    ``updated_at`` is refreshed to the current Unix time on every
    successful update.

    Args:
        id: Primary key of the file row to update.
        form_data: Fields to change.

    Returns:
        The updated file as a ``FileModel``, or ``None`` when no row with
        ``id`` exists or when the update fails.
    """
    with get_db() as db:
        try:
            file = db.query(File).filter_by(id=id).first()

            # A missing row is an expected outcome, not an internal error:
            # report it explicitly instead of letting attribute access on
            # ``None`` raise and be logged as an unexpected exception.
            if file is None:
                log.warning(f"update_file_by_id: file not found: {id}")
                return None

            if form_data.hash is not None:
                file.hash = form_data.hash

            if form_data.data is not None:
                # Build a new dict rather than mutating the stored one in
                # place; incoming keys override existing ones.
                file.data = {**(file.data or {}), **form_data.data}

            if form_data.meta is not None:
                file.meta = {**(file.meta or {}), **form_data.meta}

            file.updated_at = int(time.time())
            db.commit()
            return FileModel.model_validate(file)
        except Exception as e:
            # Best-effort update: callers receive None on any failure.
            log.exception(f"Error updating file completely by id: {e}")
            return None
|
||||||
|
|
||||||
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSpl
|
||||||
from langchain_text_splitters import MarkdownHeaderTextSplitter
|
from langchain_text_splitters import MarkdownHeaderTextSplitter
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from open_webui.models.files import FileModel, Files
|
from open_webui.models.files import FileModel, FileUpdateForm, Files
|
||||||
from open_webui.models.knowledge import Knowledges
|
from open_webui.models.knowledge import Knowledges
|
||||||
from open_webui.storage.provider import Storage
|
from open_webui.storage.provider import Storage
|
||||||
|
|
||||||
|
|
@ -2452,16 +2452,19 @@ def process_files_batch(
|
||||||
"""
|
"""
|
||||||
Process a batch of files and save them to the vector database.
|
Process a batch of files and save them to the vector database.
|
||||||
"""
|
"""
|
||||||
results: List[BatchProcessFilesResult] = []
|
|
||||||
errors: List[BatchProcessFilesResult] = []
|
|
||||||
collection_name = form_data.collection_name
|
collection_name = form_data.collection_name
|
||||||
|
|
||||||
|
file_results: List[BatchProcessFilesResult] = []
|
||||||
|
file_errors: List[BatchProcessFilesResult] = []
|
||||||
|
file_updates: List[FileUpdateForm] = []
|
||||||
|
|
||||||
# Prepare all documents first
|
# Prepare all documents first
|
||||||
all_docs: List[Document] = []
|
all_docs: List[Document] = []
|
||||||
|
|
||||||
for file in form_data.files:
|
for file in form_data.files:
|
||||||
try:
|
try:
|
||||||
text_content = file.data.get("content", "")
|
text_content = file.data.get("content", "")
|
||||||
|
|
||||||
docs: List[Document] = [
|
docs: List[Document] = [
|
||||||
Document(
|
Document(
|
||||||
page_content=text_content.replace("<br/>", "\n"),
|
page_content=text_content.replace("<br/>", "\n"),
|
||||||
|
|
@ -2475,16 +2478,22 @@ def process_files_batch(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
hash = calculate_sha256_string(text_content)
|
|
||||||
Files.update_file_hash_by_id(file.id, hash)
|
|
||||||
Files.update_file_data_by_id(file.id, {"content": text_content})
|
|
||||||
|
|
||||||
all_docs.extend(docs)
|
all_docs.extend(docs)
|
||||||
results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
|
|
||||||
|
file_updates.append(
|
||||||
|
FileUpdateForm(
|
||||||
|
id=file.id,
|
||||||
|
hash=calculate_sha256_string(text_content),
|
||||||
|
data={"content": text_content},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
file_results.append(
|
||||||
|
BatchProcessFilesResult(file_id=file.id, status="prepared")
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
|
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
|
||||||
errors.append(
|
file_errors.append(
|
||||||
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
|
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2500,20 +2509,18 @@ def process_files_batch(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update all files with collection name
|
# Update all files with collection name
|
||||||
for result in results:
|
for file_update, file_result in zip(file_updates, file_results):
|
||||||
Files.update_file_metadata_by_id(
|
Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
|
||||||
result.file_id, {"collection_name": collection_name}
|
file_result.status = "completed"
|
||||||
)
|
|
||||||
result.status = "completed"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(
|
log.error(
|
||||||
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
|
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
|
||||||
)
|
)
|
||||||
for result in results:
|
for file_result in file_results:
|
||||||
result.status = "failed"
|
file_result.status = "failed"
|
||||||
errors.append(
|
file_errors.append(
|
||||||
BatchProcessFilesResult(file_id=result.file_id, error=str(e))
|
BatchProcessFilesResult(file_id=file_result.file_id, error=str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
return BatchProcessFilesResponse(results=results, errors=errors)
|
return BatchProcessFilesResponse(results=file_results, errors=file_errors)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue