From 0eb33e8e1289c87ab1783ba9ddd9cd396e4dab94 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Tue, 16 Dec 2025 13:49:00 -0500
Subject: [PATCH] refac: logit bias handling

---
 backend/open_webui/routers/openai.py   |  9 +++++----
 backend/open_webui/utils/middleware.py |  7 ++++---
 backend/open_webui/utils/misc.py       | 22 ++++++++++++----------
 3 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py
index a74a59ca1f..eb8a93c8e2 100644
--- a/backend/open_webui/routers/openai.py
+++ b/backend/open_webui/routers/openai.py
@@ -891,10 +891,11 @@ async def generate_chat_completion(
         del payload["max_tokens"]

     # Convert the modified body back to JSON
-    if "logit_bias" in payload:
-        payload["logit_bias"] = json.loads(
-            convert_logit_bias_input_to_json(payload["logit_bias"])
-        )
+    if "logit_bias" in payload and payload["logit_bias"]:
+        logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])
+
+        if logit_bias:
+            payload["logit_bias"] = json.loads(logit_bias)

     headers, cookies = await get_headers_and_cookies(
         request, url, key, api_config, metadata, user=user
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 6c285b9367..aafa3879ff 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -1100,9 +1100,10 @@ def apply_params_to_form_data(form_data, model):

     if "logit_bias" in params and params["logit_bias"] is not None:
         try:
-            form_data["logit_bias"] = json.loads(
-                convert_logit_bias_input_to_json(params["logit_bias"])
-            )
+            logit_bias = convert_logit_bias_input_to_json(params["logit_bias"])
+
+            if logit_bias:
+                form_data["logit_bias"] = json.loads(logit_bias)
         except Exception as e:
             log.exception(f"Error parsing logit_bias: {e}")

diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index e0e21249f6..4501dfaf07 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -523,16 +523,18 @@ def parse_ollama_modelfile(model_text):
     return data


-def convert_logit_bias_input_to_json(user_input):
-    logit_bias_pairs = user_input.split(",")
-    logit_bias_json = {}
-    for pair in logit_bias_pairs:
-        token, bias = pair.split(":")
-        token = str(token.strip())
-        bias = int(bias.strip())
-        bias = 100 if bias > 100 else -100 if bias < -100 else bias
-        logit_bias_json[token] = bias
-    return json.dumps(logit_bias_json)
+def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
+    if user_input:
+        logit_bias_pairs = user_input.split(",")
+        logit_bias_json = {}
+        for pair in logit_bias_pairs:
+            token, bias = pair.split(":")
+            token = str(token.strip())
+            bias = int(bias.strip())
+            bias = 100 if bias > 100 else -100 if bias < -100 else bias
+            logit_bias_json[token] = bias
+        return json.dumps(logit_bias_json)
+    return None


 def freeze(value):
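
For reference, a minimal self-contained sketch of the refactored flow. The helper mirrors the patched `convert_logit_bias_input_to_json`; the `apply_logit_bias` wrapper is hypothetical and only illustrates the caller-side pattern now used in openai.py and middleware.py (run the conversion only for a non-empty value, and call `json.loads` only when the helper actually returned a string):

```python
import json
from typing import Optional


def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
    # Mirrors the patched helper: empty or falsy input returns None
    # instead of falling through to the pair-splitting logic.
    if user_input:
        logit_bias_pairs = user_input.split(",")
        logit_bias_json = {}
        for pair in logit_bias_pairs:
            token, bias = pair.split(":")
            token = str(token.strip())
            bias = int(bias.strip())
            # Clamp bias to the OpenAI-accepted range of [-100, 100].
            bias = 100 if bias > 100 else -100 if bias < -100 else bias
            logit_bias_json[token] = bias
        return json.dumps(logit_bias_json)
    return None


def apply_logit_bias(payload: dict) -> dict:
    # Hypothetical caller showing the guarded pattern from the patch:
    # skip empty values up front, and only json.loads a real string.
    if "logit_bias" in payload and payload["logit_bias"]:
        logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])
        if logit_bias:
            payload["logit_bias"] = json.loads(logit_bias)
    return payload


if __name__ == "__main__":
    print(apply_logit_bias({"logit_bias": "50256:-100, 1234:150"}))
    # {'logit_bias': {'50256': -100, '1234': 100}}
    print(apply_logit_bias({"logit_bias": ""}))
    # {'logit_bias': ''}  -- left untouched, nothing is parsed
```

The motivation visible in the diff: the openai.py path had no try/except around the conversion, so an empty `logit_bias` string would previously hit `pair.split(":")` on an empty pair and raise; the new guards skip the conversion entirely and only parse a non-None result.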