refac: logit bias handling

This commit is contained in:
Timothy Jaeryang Baek 2025-12-16 13:49:00 -05:00
parent 59d6eb2bad
commit 0eb33e8e12
3 changed files with 21 additions and 17 deletions

View file

@ -891,10 +891,11 @@ async def generate_chat_completion(
del payload["max_tokens"]
# Convert the modified body back to JSON
if "logit_bias" in payload:
payload["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(payload["logit_bias"])
)
if "logit_bias" in payload and payload["logit_bias"]:
logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])
if logit_bias:
payload["logit_bias"] = json.loads(logit_bias)
headers, cookies = await get_headers_and_cookies(
request, url, key, api_config, metadata, user=user

View file

@ -1100,9 +1100,10 @@ def apply_params_to_form_data(form_data, model):
if "logit_bias" in params and params["logit_bias"] is not None:
try:
form_data["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(params["logit_bias"])
)
logit_bias = convert_logit_bias_input_to_json(params["logit_bias"])
if logit_bias:
form_data["logit_bias"] = json.loads(logit_bias)
except Exception as e:
log.exception(f"Error parsing logit_bias: {e}")

View file

@ -523,16 +523,18 @@ def parse_ollama_modelfile(model_text):
return data
def convert_logit_bias_input_to_json(user_input):
logit_bias_pairs = user_input.split(",")
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(":")
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
logit_bias_json[token] = bias
return json.dumps(logit_bias_json)
def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
if user_input:
logit_bias_pairs = user_input.split(",")
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(":")
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
logit_bias_json[token] = bias
return json.dumps(logit_bias_json)
return None
def freeze(value):