diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index cbaefe1f3e..d24bd5dcf1 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1437,11 +1437,15 @@ async def chat_completion( stream_delta_chunk_size = form_data.get("params", {}).get( "stream_delta_chunk_size" ) + reasoning_tags = form_data.get("params", {}).get("reasoning_tags") # Model Params if model_info_params.get("stream_delta_chunk_size"): stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size") + if model_info_params.get("reasoning_tags") is not None: + reasoning_tags = model_info_params.get("reasoning_tags") + metadata = { "user_id": user.id, "chat_id": form_data.pop("chat_id", None), @@ -1457,6 +1461,7 @@ async def chat_completion( "direct": model_item.get("direct", False), "params": { "stream_delta_chunk_size": stream_delta_chunk_size, + "reasoning_tags": reasoning_tags, "function_calling": ( "native" if ( diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 66d4ad6286..41e56e6530 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -111,6 +111,20 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) +DEFAULT_REASONING_TAGS = [ + ("", ""), + ("", ""), + ("", ""), + ("", ""), + ("", ""), + ("", ""), + ("<|begin_of_thought|>", "<|end_of_thought|>"), + ("◁think▷", "◁/think▷"), +] +DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")] +DEFAULT_CODE_INTERPRETER_TAGS = [("", "")] + + async def chat_completion_tools_handler( request: Request, body: dict, extra_params: dict, user: UserModel, models, tools ) -> tuple[dict, dict]: @@ -694,6 +708,7 @@ def apply_params_to_form_data(form_data, model): "stream_response": bool, "stream_delta_chunk_size": int, "function_calling": str, + "reasoning_tags": list, "system": str, } @@ -1811,27 +1826,23 @@ async def process_chat_response( } ] - # We might want to disable this by default - DETECT_REASONING = True - DETECT_SOLUTION = True + reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags") + DETECT_REASONING_TAGS = reasoning_tags_param is not False DETECT_CODE_INTERPRETER = metadata.get("features", {}).get( "code_interpreter", False ) - reasoning_tags = [ - ("", ""), - ("", ""), - ("", ""), - ("", ""), - ("", ""), - ("", ""), - ("<|begin_of_thought|>", "<|end_of_thought|>"), - ("◁think▷", "◁/think▷"), - ] - - code_interpreter_tags = [("", "")] - - solution_tags = [("<|begin_of_solution|>", "<|end_of_solution|>")] + reasoning_tags = [] + if DETECT_REASONING_TAGS: + if ( + isinstance(reasoning_tags_param, list) + and len(reasoning_tags_param) == 2 + ): + reasoning_tags = [ + (reasoning_tags_param[0], reasoning_tags_param[1]) + ] + else: + reasoning_tags = DEFAULT_REASONING_TAGS try: for event in events: @@ -2083,7 +2094,7 @@ async def process_chat_response( content_blocks[-1]["content"] + value ) - if DETECT_REASONING: + if DETECT_REASONING_TAGS: content, content_blocks, _ = ( tag_content_handler( "reasoning", @@ -2093,11 +2104,20 @@ async def process_chat_response( ) ) + content, content_blocks, _ = ( + tag_content_handler( + "solution", + DEFAULT_SOLUTION_TAGS, + content, + content_blocks, + ) + ) + if DETECT_CODE_INTERPRETER: content, content_blocks, end = ( tag_content_handler( "code_interpreter", - code_interpreter_tags, + DEFAULT_CODE_INTERPRETER_TAGS, content, content_blocks, ) @@ -2106,16 +2126,6 @@ async def process_chat_response( if end: break - if DETECT_SOLUTION: - content, content_blocks, _ = ( - tag_content_handler( - "solution", - solution_tags, - content, - content_blocks, - ) - ) - if ENABLE_REALTIME_CHAT_SAVE: # Save message in the database Chats.upsert_message_to_chat_by_id_and_message_id( diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 811ba75c9f..39c785854a 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -63,6 +63,7 @@ def remove_open_webui_params(params: dict) -> dict: "stream_response": bool, "stream_delta_chunk_size": int, "function_calling": str, + "reasoning_tags": list, "system": str, } diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index b900046a92..4bc0647a76 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -17,6 +17,7 @@ stream_response: null, // Set stream responses for this model individually stream_delta_chunk_size: null, // Set the chunk size for streaming responses function_calling: null, + reasoning_tags: null, seed: null, stop: null, temperature: null, @@ -175,6 +176,69 @@ +
+ +
+
+ {$i18n.t('Reasoning Tags')} +
+ +
+
+ + {#if ![true, false, null].includes(params?.reasoning_tags ?? null) && (params?.reasoning_tags ?? []).length === 2} +
+
+ +
+ +
+ +
+
+ {/if} +
+