diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 32c3316541..32346d3b96 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -20,7 +20,16 @@ jobs:
 
       - name: Build and run Compose Stack
         run: |
-          docker compose up --detach --build
+          docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build
+
+      - name: Wait for Ollama to be up
+        timeout-minutes: 5
+        run: |
+          until curl --output /dev/null --silent --fail http://localhost:11434; do
+            printf '.'
+            sleep 1
+          done
+          echo "Service is up!"
 
       - name: Preload Ollama model
         run: |
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 97cf9977ba..a57a126c8d 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -80,6 +80,7 @@ from config import (
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     ENABLE_RAG_HYBRID_SEARCH,
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     RAG_RERANKING_MODEL,
     PDF_EXTRACT_IMAGES,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
@@ -91,7 +92,7 @@ from config import (
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
-    ENABLE_LOCAL_WEB_FETCH,
+    ENABLE_RAG_LOCAL_WEB_FETCH,
 )
 
 from constants import ERROR_MESSAGES
@@ -105,6 +106,9 @@ app.state.TOP_K = RAG_TOP_K
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 
 app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+)
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -114,6 +118,7 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 
+
 app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
@@ -313,6 +318,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
             "chunk_size": app.state.CHUNK_SIZE,
             "chunk_overlap": app.state.CHUNK_OVERLAP,
         },
+        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     }
 
 
@@ -322,15 +328,34 @@ class ChunkParamUpdateForm(BaseModel):
 
 
 class ConfigUpdateForm(BaseModel):
-    pdf_extract_images: bool
-    chunk: ChunkParamUpdateForm
+    pdf_extract_images: Optional[bool] = None
+    chunk: Optional[ChunkParamUpdateForm] = None
+    web_loader_ssl_verification: Optional[bool] = None
 
 
 @app.post("/config/update")
 async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
-    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
-    app.state.CHUNK_SIZE = form_data.chunk.chunk_size
-    app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
+    app.state.PDF_EXTRACT_IMAGES = (
+        form_data.pdf_extract_images
+        if form_data.pdf_extract_images is not None
+        else app.state.PDF_EXTRACT_IMAGES
+    )
+
+    app.state.CHUNK_SIZE = (
+        form_data.chunk.chunk_size if form_data.chunk is not None else app.state.CHUNK_SIZE
+    )
+
+    app.state.CHUNK_OVERLAP = (
+        form_data.chunk.chunk_overlap
+        if form_data.chunk is not None
+        else app.state.CHUNK_OVERLAP
+    )
+
+    app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+        form_data.web_loader_ssl_verification
+        if form_data.web_loader_ssl_verification is not None
+        else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+    )
 
     return {
         "status": True,
@@ -339,6 +364,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
             "chunk_size": app.state.CHUNK_SIZE,
             "chunk_overlap": app.state.CHUNK_OVERLAP,
         },
+        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     }
 
 
@@ -490,7 +516,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
 def store_web(form_data: UrlForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
-        loader = get_web_loader(form_data.url)
+        loader = get_web_loader(
+            form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+        )
         data = loader.load()
 
         collection_name = form_data.collection_name
@@ -510,12 +538,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
             detail=ERROR_MESSAGES.DEFAULT(e),
         )
 
-
-def get_web_loader(url: Union[str, Sequence[str]]):
+def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
     # Check if the URL is valid
     if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return WebBaseLoader(url)
+    return WebBaseLoader(url, verify_ssl=verify_ssl)
 
 
 def validate_url(url: Union[str, Sequence[str]]):
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index 96401b277e..62191481fd 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -287,14 +287,14 @@ def rag_messages(
     for doc in docs:
         context = None
 
-        collection = doc.get("collection_name")
-        if collection:
-            collection = [collection]
-        else:
-            collection = doc.get("collection_names", [])
+        collection_names = (
+            doc["collection_names"]
+            if doc["type"] == "collection"
+            else [doc["collection_name"]]
+        )
 
-        collection = set(collection).difference(extracted_collections)
-        if not collection:
+        collection_names = set(collection_names).difference(extracted_collections)
+        if not collection_names:
             log.debug(f"skipping {doc} as it has already been extracted")
             continue
 
@@ -304,11 +304,7 @@ def rag_messages(
         else:
             if hybrid_search:
                 context = query_collection_with_hybrid_search(
-                    collection_names=(
-                        doc["collection_names"]
-                        if doc["type"] == "collection"
-                        else [doc["collection_name"]]
-                    ),
+                    collection_names=collection_names,
                     query=query,
                     embedding_function=embedding_function,
                     k=k,
@@ -317,11 +313,7 @@ def rag_messages(
             else:
                 context = query_collection(
-                    collection_names=(
-                        doc["collection_names"]
-                        if doc["type"] == "collection"
-                        else [doc["collection_name"]]
-                    ),
+                    collection_names=collection_names,
                     query=query,
                     embedding_function=embedding_function,
                     k=k,
@@ -331,18 +323,31 @@ def rag_messages(
             context = None
 
         if context:
-            relevant_contexts.append(context)
+            relevant_contexts.append({**context, "source": doc})
 
-        extracted_collections.extend(collection)
+        extracted_collections.extend(collection_names)
 
     context_string = ""
+
+    citations = []
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                items = [item for item in context["documents"][0] if item is not None]
-                context_string += "\n\n".join(items)
+                context_string += "\n\n".join(
+                    [text for text in context["documents"][0] if text is not None]
+                )
+
+                if "metadatas" in context:
+                    citations.append(
+                        {
+                            "source": context["source"],
+                            "document": context["documents"][0],
+                            "metadata": context["metadatas"][0],
+                        }
+                    )
         except Exception as e:
             log.exception(e)
+
     context_string = context_string.strip()
 
     ra_content = rag_template(
@@ -371,7 +376,7 @@ def rag_messages(
 
     messages[last_user_message_idx] = new_user_message
 
-    return messages
+    return messages, citations
 
 
 def get_model_path(model: str, update_model: bool = False):
diff --git a/backend/config.py b/backend/config.py
index 10d4496a51..24e124ea92 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -18,6 +18,18 @@ from secrets import token_bytes
 
 from constants import ERROR_MESSAGES
 
+####################################
+# Load .env file
+####################################
+
+try:
+    from dotenv import load_dotenv, find_dotenv
+
+    load_dotenv(find_dotenv("../.env"))
+except ImportError:
+    print("dotenv not installed, skipping...")
+
+
 ####################################
 # LOGGING
 ####################################
@@ -59,16 +71,6 @@ for source in log_sources:
 
 log.setLevel(SRC_LOG_LEVELS["CONFIG"])
 
-####################################
-# Load .env file
-####################################
-
-try:
-    from dotenv import load_dotenv, find_dotenv
-
-    load_dotenv(find_dotenv("../.env"))
-except ImportError:
-    log.warning("dotenv not installed, skipping...")
 
 WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
 if WEBUI_NAME != "Open WebUI":
@@ -454,6 +456,11 @@ ENABLE_RAG_HYBRID_SEARCH = (
     os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
 )
 
+
+ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
+)
+
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
 
 PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
@@ -531,7 +538,9 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
 RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
 RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
 
-ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
+ENABLE_RAG_LOCAL_WEB_FETCH = (
+    os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
+)
 
 SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
 GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
diff --git a/backend/main.py b/backend/main.py
index e36b8296af..dc2175ad55 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
-
+from starlette.responses import StreamingResponse
 
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
@@ -102,6 +102,8 @@ origins = ["*"]
 
 class RAGMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
+        return_citations = False
+
         if request.method == "POST" and (
             "/api/chat" in request.url.path or "/chat/completions" in request.url.path
         ):
@@ -114,11 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
+            return_citations = data.get("citations", False)
+            if "citations" in data:
+                del data["citations"]
+
             # Example: Add a new key-value pair or modify existing ones
             # data["modified"] = True  # Example modification
             if "docs" in data:
                 data = {**data}
-                data["messages"] = rag_messages(
+                data["messages"], citations = rag_messages(
                     docs=data["docs"],
                     messages=data["messages"],
                     template=rag_app.state.RAG_TEMPLATE,
@@ -130,7 +136,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 )
                 del data["docs"]
 
-                log.debug(f"data['messages']: {data['messages']}")
+                log.debug(
+                    f"data['messages']: {data['messages']}, citations: {citations}"
+                )
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
 
@@ -148,11 +156,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
             ]
 
         response = await call_next(request)
+
+        if return_citations:
+            # Inject the citations into the response
+            if isinstance(response, StreamingResponse):
+                # If it's a streaming response, inject it as SSE event or NDJSON line
+                content_type = response.headers.get("Content-Type")
+                if content_type and "text/event-stream" in content_type:
+                    return StreamingResponse(
+                        self.openai_stream_wrapper(response.body_iterator, citations),
+                    )
+                if content_type and "application/x-ndjson" in content_type:
+                    return StreamingResponse(
+                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                    )
+
         return response
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
 
+    async def openai_stream_wrapper(self, original_generator, citations):
+        yield f"data: {json.dumps({'citations': citations})}\n\n"
+        async for data in original_generator:
+            yield data
+
+    async def ollama_stream_wrapper(self, original_generator, citations):
+        yield f"{json.dumps({'citations': citations})}\n"
+        async for data in original_generator:
+            yield data
+
 
 app.add_middleware(RAGMiddleware)
diff --git a/src/app.css b/src/app.css
index dccc6e092c..7d111bd48b 100644
--- a/src/app.css
+++ b/src/app.css
@@ -82,3 +82,12 @@ select {
 .katex-mathml {
 	display: none;
 }
+
+.scrollbar-none:active::-webkit-scrollbar-thumb,
+.scrollbar-none:focus::-webkit-scrollbar-thumb,
+.scrollbar-none:hover::-webkit-scrollbar-thumb {
+	visibility: visible;
+}
+.scrollbar-none::-webkit-scrollbar-thumb {
+	visibility: hidden;
+}
diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts
index a94aceaced..b28d060718 100644
--- a/src/lib/apis/ollama/index.ts
+++ b/src/lib/apis/ollama/index.ts
@@ -159,7 +159,11 @@ export const generateTitle = async (
 		body: JSON.stringify({
 			model: model,
 			prompt: template,
-			stream: false
+			stream: false,
+			options: {
+				// Restrict the number of tokens generated to 50
+				num_predict: 50
+			}
 		})
 	})
 		.then(async (res) => {
diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts
index ac770e5b7d..944e2d40dc 100644
--- a/src/lib/apis/openai/index.ts
+++ b/src/lib/apis/openai/index.ts
@@ -295,7 +295,9 @@ export const generateTitle = async (
 					content: template
 				}
 			],
-			stream: false
+			stream: false,
+			// Restricting the max tokens to 50 to avoid long titles
+			max_tokens: 50
 		})
 	})
 		.then(async (res) => {
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index a9d163f874..ccf166dabb 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -33,8 +33,9 @@ type ChunkConfigForm = {
 };
 
 type RAGConfigForm = {
-	pdf_extract_images: boolean;
-	chunk: ChunkConfigForm;
+	pdf_extract_images?: boolean;
+	chunk?: ChunkConfigForm;
+	web_loader_ssl_verification?: boolean;
 };
 
 export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => {
diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts
index a72dbe47da..0e87c2524a 100644
--- a/src/lib/apis/streaming/index.ts
+++ b/src/lib/apis/streaming/index.ts
@@ -4,6 +4,8 @@ import type { ParsedEvent } from 'eventsource-parser';
 type TextStreamUpdate = {
 	done: boolean;
 	value: string;
+	// eslint-disable-next-line @typescript-eslint/no-explicit-any
+	citations?: any;
 };
 
 // createOpenAITextStream takes a responseBody with a SSE response,
@@ -45,6 +47,11 @@ async function* openAIStreamToIterator(
 			const parsedData = JSON.parse(data);
 			console.log(parsedData);
 
+			if (parsedData.citations) {
+				yield { done: false, value: '', citations: parsedData.citations };
+				continue;
+			}
+
 			yield { done: false, value: parsedData.choices?.[0]?.delta?.content ?? '' };
 		} catch (e) {
 			console.error('Error extracting delta from SSE event:', e);
 		}
 	};
@@ -62,6 +69,10 @@ async function* streamLargeDeltasAsRandomChunks(
 			yield textStreamUpdate;
 			return;
 		}
+		if (textStreamUpdate.citations) {
+			yield textStreamUpdate;
+			continue;
+		}
 		let content = textStreamUpdate.value;
 		if (content.length < 5) {
 			yield { done: false, value: content };
diff --git a/src/lib/components/chat/Messages/CitationsModal.svelte b/src/lib/components/chat/Messages/CitationsModal.svelte
new file mode 100644
index 0000000000..c7db034be7
--- /dev/null
+++ b/src/lib/components/chat/Messages/CitationsModal.svelte
@@ -0,0 +1,77 @@
+
+
+
+
+
+
+ {$i18n.t('Citation')} +
+ +
+ +
+
+ {#each mergedDocuments as document, documentIdx} +
+
+ {$i18n.t('Source')} +
+
+ {document.source?.name ?? $i18n.t('No source available')} +
+
+
+
+ {$i18n.t('Content')} +
+
+							{document.document}
+						
+
+ + {#if documentIdx !== mergedDocuments.length - 1} +
+ {/if} + {/each} +
+
+
+
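Taken together, the backend changes above define the citation contract: rag_messages() now returns (messages, citations), and RAGMiddleware prepends the citations as one extra SSE event (data: {"citations": [...]}) or one NDJSON line ahead of the model stream, which openAIStreamToIterator then peels off. A minimal TypeScript sketch of that payload shape and a guard for it, with field names taken from the Python dicts above; the type and function names are assumptions for illustration and are not part of the patch:

// One citation entry as assembled in backend/apps/rag/utils.py:
// "source" is the doc reference from the request, while "document" and
// "metadata" are the parallel arrays returned by the collection query.
type Citation = {
	source: { type?: string; name?: string; collection_name?: string; collection_names?: string[] };
	document: string[];
	metadata: { source?: string }[];
};

// The middleware emits citations as the first stream element, so a client
// can detect and strip them before normal delta handling.
function extractCitations(parsed: unknown): Citation[] | null {
	if (parsed !== null && typeof parsed === 'object' && 'citations' in parsed) {
		return (parsed as { citations: Citation[] }).citations;
	}
	return null;
}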
diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte
index 4d87f929f3..6d027ab292 100644
--- a/src/lib/components/chat/Messages/ResponseMessage.svelte
+++ b/src/lib/components/chat/Messages/ResponseMessage.svelte
@@ -32,6 +32,7 @@
 	import { WEBUI_BASE_URL } from '$lib/constants';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
 	import RateComment from './RateComment.svelte';
+	import CitationsModal from '$lib/components/chat/Messages/CitationsModal.svelte';
 
 	export let modelfiles = [];
 	export let message;
@@ -65,6 +66,9 @@
 	let showRateComment = false;
 
+	let showCitationModal = false;
+	let selectedCitation = null;
+
 	$: tokens = marked.lexer(sanitizeResponseContent(message.content));
 
 	const renderer = new marked.Renderer();
@@ -324,6 +328,8 @@
 	});
 </script>
 
+<CitationsModal bind:show={showCitationModal} citation={selectedCitation} />
+
 {#key message.id}
{/if} -
@@ -441,436 +446,484 @@ {/each} {/if} - - {#if message.done} -
- {#if siblings.length > 1} -
- - -
- {siblings.indexOf(message.id) + 1} / {siblings.length} -
- - -
- {/if} - - {#if !readOnly} - - - - {/if} - - - - - - {#if !readOnly} - - - - - - - - {/if} - - - - - - {#if $config.images && !readOnly} - - - - {/if} - - {#if message.info} - - - - {/if} - - {#if isLastMessage && !readOnly} - - - - - - - - {/if} -
- {/if} - - {#if showRateComment} - { - updateChatMessages(); - }} - /> - {/if}
{/if}
+ + + + {#if message.citations} +
+
+ {#each message.citations.reduce((acc, citation) => { + citation.document.forEach((document, index) => { + const metadata = citation.metadata?.[index]; + const id = metadata?.source ?? 'N/A'; + + const existingSource = acc.find((item) => item.id === id); + + if (existingSource) { + existingSource.document.push(document); + existingSource.metadata.push(metadata); + } else { + acc.push( { id: id, source: citation?.source, document: [document], metadata: metadata ? [metadata] : [] } ); + } + }); + return acc; + }, []) as citation, idx} +
+
+ [{idx + 1}] +
+ + +
+ {/each} +
+ {/if} + + {#if message.done} +
+ {#if siblings.length > 1} +
+ + +
+ {siblings.indexOf(message.id) + 1} / {siblings.length} +
+ + +
+ {/if} + + {#if !readOnly} + + + + {/if} + + + + + + {#if !readOnly} + + + + + + + + {/if} + + + + + + {#if $config.images && !readOnly} + + + + {/if} + + {#if message.info} + + + + {/if} + + {#if isLastMessage && !readOnly} + + + + + + + + {/if} +
+ {/if} + + {#if showRateComment} + { + updateChatMessages(); + }} + /> + {/if} {/if} diff --git a/src/lib/components/chat/ModelSelector.svelte b/src/lib/components/chat/ModelSelector.svelte index 2874da49b8..86819afca2 100644 --- a/src/lib/components/chat/ModelSelector.svelte +++ b/src/lib/components/chat/ModelSelector.svelte @@ -82,7 +82,7 @@ {:else}
- + {/if} diff --git a/src/lib/components/chat/Settings/Account.svelte b/src/lib/components/chat/Settings/Account.svelte index 81652c0dcb..d27344bdba 100644 --- a/src/lib/components/chat/Settings/Account.svelte +++ b/src/lib/components/chat/Settings/Account.svelte @@ -447,7 +447,7 @@ {/if} - + {/if}
diff --git a/src/lib/components/chat/Settings/Connections.svelte b/src/lib/components/chat/Settings/Connections.svelte index 87825f2a5b..3ecfa0d74d 100644 --- a/src/lib/components/chat/Settings/Connections.svelte +++ b/src/lib/components/chat/Settings/Connections.svelte @@ -164,7 +164,7 @@
diff --git a/src/lib/components/chat/ShareChatModal.svelte b/src/lib/components/chat/ShareChatModal.svelte index 10145b08e0..50f38d8a9b 100644 --- a/src/lib/components/chat/ShareChatModal.svelte +++ b/src/lib/components/chat/ShareChatModal.svelte @@ -97,9 +97,10 @@
 		{#if chat.share_id}
-			You have shared this chat before.
+			{$i18n.t('You have shared this chat')}
+			{$i18n.t('before')}.
-			Click here to
+			{$i18n.t('Click here to')}
			and create a new shared link.
+				}}
+			>{$i18n.t('delete this link')}
+
+			{$i18n.t('and create a new shared link.')}
 		{:else}
-			Messages you send after creating your link won't be shared. Users with the URL will be
-			able to view the shared chat.
+			{$i18n.t(
+				"Messages you send after creating your link won't be shared. Users with the URL will be able to view the shared chat."
+			)}
 		{/if}
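The inline reduce in ResponseMessage.svelte above groups citation chunks by metadata.source so that each distinct source renders as one numbered [idx + 1] button. The same logic as a standalone function, sketched here with the patch's field names so the grouping is testable in isolation; the type names are assumptions, not part of the patch:

type CitationMeta = { source?: string };
type Citation = { source?: { name?: string }; document: string[]; metadata?: CitationMeta[] };
type GroupedCitation = {
	id: string;
	source?: { name?: string };
	document: string[];
	metadata: (CitationMeta | undefined)[];
};

// Chunks sharing a metadata.source accumulate into one entry; chunks with
// no source metadata collapse into a single 'N/A' group, as in the template.
function groupCitationsBySource(citations: Citation[]): GroupedCitation[] {
	return citations.reduce<GroupedCitation[]>((acc, citation) => {
		citation.document.forEach((doc, index) => {
			const metadata = citation.metadata?.[index];
			const id = metadata?.source ?? 'N/A';
			const existing = acc.find((item) => item.id === id);
			if (existing) {
				existing.document.push(doc);
				existing.metadata.push(metadata);
			} else {
				acc.push({ id, source: citation.source, document: [doc], metadata: [metadata] });
			}
		});
		return acc;
	}, []);
}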
diff --git a/src/lib/components/common/ImagePreview.svelte b/src/lib/components/common/ImagePreview.svelte index badabebda4..99882cbca4 100644 --- a/src/lib/components/common/ImagePreview.svelte +++ b/src/lib/components/common/ImagePreview.svelte @@ -51,7 +51,7 @@ +
+ + + +
+ +
+ diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index 16982b6e25..eaa7bbdff9 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -1,8 +1,6 @@ + +
{ + submitHandler(); + saveHandler(); + }} +> +
+
+
{$i18n.t('Query Params')}
+ +
+
+
{$i18n.t('Top K')}
+ +
+ +
+
+ + {#if querySettings.hybrid === true} +
+
+ {$i18n.t('Minimum Score')} +
+ +
+ +
+
+ {/if} +
+ + {#if querySettings.hybrid === true} +
+ {$i18n.t( + 'Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.' + )} +
+ +
+ {/if} + +
+
{$i18n.t('RAG Template')}
+
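Since every field of RAGConfigForm is now optional and the backend keeps the current value for any field left out, callers can send partial updates. A short usage sketch against the updateRAGConfig helper above, assuming an async context; the localStorage token lookup is a placeholder, not part of the patch:

import { updateRAGConfig } from '$lib/apis/rag';

// Disable SSL verification for the web loader only; pdf_extract_images and
// chunk settings are omitted, so the backend leaves them unchanged.
const res = await updateRAGConfig(localStorage.token, {
	web_loader_ssl_verification: false
});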