mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-24 02:05:23 +00:00
refac: logit bias handling
This commit is contained in:
parent
59d6eb2bad
commit
0eb33e8e12
3 changed files with 21 additions and 17 deletions
|
|
@ -891,10 +891,11 @@ async def generate_chat_completion(
|
|||
del payload["max_tokens"]
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
if "logit_bias" in payload:
|
||||
payload["logit_bias"] = json.loads(
|
||||
convert_logit_bias_input_to_json(payload["logit_bias"])
|
||||
)
|
||||
if "logit_bias" in payload and payload["logit_bias"]:
|
||||
logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])
|
||||
|
||||
if logit_bias:
|
||||
payload["logit_bias"] = json.loads(logit_bias)
|
||||
|
||||
headers, cookies = await get_headers_and_cookies(
|
||||
request, url, key, api_config, metadata, user=user
|
||||
|
|
|
|||
|
|
@ -1100,9 +1100,10 @@ def apply_params_to_form_data(form_data, model):
|
|||
|
||||
if "logit_bias" in params and params["logit_bias"] is not None:
|
||||
try:
|
||||
form_data["logit_bias"] = json.loads(
|
||||
convert_logit_bias_input_to_json(params["logit_bias"])
|
||||
)
|
||||
logit_bias = convert_logit_bias_input_to_json(params["logit_bias"])
|
||||
|
||||
if logit_bias:
|
||||
form_data["logit_bias"] = json.loads(logit_bias)
|
||||
except Exception as e:
|
||||
log.exception(f"Error parsing logit_bias: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -523,16 +523,18 @@ def parse_ollama_modelfile(model_text):
|
|||
return data
|
||||
|
||||
|
||||
def convert_logit_bias_input_to_json(user_input):
|
||||
logit_bias_pairs = user_input.split(",")
|
||||
logit_bias_json = {}
|
||||
for pair in logit_bias_pairs:
|
||||
token, bias = pair.split(":")
|
||||
token = str(token.strip())
|
||||
bias = int(bias.strip())
|
||||
bias = 100 if bias > 100 else -100 if bias < -100 else bias
|
||||
logit_bias_json[token] = bias
|
||||
return json.dumps(logit_bias_json)
|
||||
def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
|
||||
if user_input:
|
||||
logit_bias_pairs = user_input.split(",")
|
||||
logit_bias_json = {}
|
||||
for pair in logit_bias_pairs:
|
||||
token, bias = pair.split(":")
|
||||
token = str(token.strip())
|
||||
bias = int(bias.strip())
|
||||
bias = 100 if bias > 100 else -100 if bias < -100 else bias
|
||||
logit_bias_json[token] = bias
|
||||
return json.dumps(logit_bias_json)
|
||||
return None
|
||||
|
||||
|
||||
def freeze(value):
|
||||
|
|
|
|||
Loading…
Reference in a new issue