refac: logit bias handling

This commit is contained in:
parent 59d6eb2bad
commit 0eb33e8e12

3 changed files with 21 additions and 17 deletions
@@ -891,10 +891,11 @@ async def generate_chat_completion(
             del payload["max_tokens"]
 
     # Convert the modified body back to JSON
-    if "logit_bias" in payload:
-        payload["logit_bias"] = json.loads(
-            convert_logit_bias_input_to_json(payload["logit_bias"])
-        )
+    if "logit_bias" in payload and payload["logit_bias"]:
+        logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])
+
+        if logit_bias:
+            payload["logit_bias"] = json.loads(logit_bias)
 
     headers, cookies = await get_headers_and_cookies(
         request, url, key, api_config, metadata, user=user
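The first hunk only changes how generate_chat_completion consumes the helper: the string value is converted first, and json.loads() runs only when the helper actually returned a string, so an empty payload value or a None return can no longer raise. A minimal sketch of that guarded flow, with an illustrative payload and a stubbed helper (everything outside the diff, including the token IDs, is assumed for the example):

import json
from typing import Optional


def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
    # Stub mirroring the refactored helper in the last hunk:
    # falsy input yields None instead of raising.
    if user_input:
        return json.dumps(
            {t.strip(): int(b.strip()) for t, b in
             (pair.split(":") for pair in user_input.split(","))}
        )
    return None


payload = {"model": "gpt-4o", "logit_bias": "50256:-100, 15496:50"}

# Guarded conversion: json.loads() is only reached when the helper
# returned a string, so it is never handed None (a TypeError).
if "logit_bias" in payload and payload["logit_bias"]:
    logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])

    if logit_bias:
        payload["logit_bias"] = json.loads(logit_bias)

print(payload["logit_bias"])  # {'50256': -100, '15496': 50}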
@@ -1100,9 +1100,10 @@ def apply_params_to_form_data(form_data, model):
 
     if "logit_bias" in params and params["logit_bias"] is not None:
         try:
-            form_data["logit_bias"] = json.loads(
-                convert_logit_bias_input_to_json(params["logit_bias"])
-            )
+            logit_bias = convert_logit_bias_input_to_json(params["logit_bias"])
+
+            if logit_bias:
+                form_data["logit_bias"] = json.loads(logit_bias)
         except Exception as e:
             log.exception(f"Error parsing logit_bias: {e}")
 
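apply_params_to_form_data gets the same two-step guard, still inside its existing try/except, so a malformed per-model logit_bias value is logged and skipped rather than aborting the request. A hedged sketch of that error path (the logger setup, the params/form_data contents, and the simplified helper are assumptions for illustration):

import json
import logging

logging.basicConfig()
log = logging.getLogger(__name__)


def convert_logit_bias_input_to_json(user_input):
    # Simplified stand-in: like the real helper, a pair without a ":" makes
    # pair.split(":") fail to unpack, and the exception reaches the caller.
    if user_input:
        logit_bias_json = {}
        for pair in user_input.split(","):
            token, bias = pair.split(":")
            logit_bias_json[token.strip()] = int(bias.strip())
        return json.dumps(logit_bias_json)
    return None


form_data = {"model": "llama3"}
params = {"logit_bias": "not-a-valid-pair"}  # malformed: no "token:bias" colon

try:
    logit_bias = convert_logit_bias_input_to_json(params["logit_bias"])

    if logit_bias:
        form_data["logit_bias"] = json.loads(logit_bias)
except Exception as e:
    # Same handler as the hunk above: log and move on, leaving form_data
    # without a logit_bias key.
    log.exception(f"Error parsing logit_bias: {e}")

print("logit_bias" in form_data)  # False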
@@ -523,16 +523,18 @@ def parse_ollama_modelfile(model_text):
     return data
 
 
-def convert_logit_bias_input_to_json(user_input):
-    logit_bias_pairs = user_input.split(",")
-    logit_bias_json = {}
-    for pair in logit_bias_pairs:
-        token, bias = pair.split(":")
-        token = str(token.strip())
-        bias = int(bias.strip())
-        bias = 100 if bias > 100 else -100 if bias < -100 else bias
-        logit_bias_json[token] = bias
-    return json.dumps(logit_bias_json)
+def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
+    if user_input:
+        logit_bias_pairs = user_input.split(",")
+        logit_bias_json = {}
+        for pair in logit_bias_pairs:
+            token, bias = pair.split(":")
+            token = str(token.strip())
+            bias = int(bias.strip())
+            bias = 100 if bias > 100 else -100 if bias < -100 else bias
+            logit_bias_json[token] = bias
+        return json.dumps(logit_bias_json)
+    return None
 
 
 def freeze(value):
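Taken on its own, the refactored helper behaves as follows (the function body is copied from the hunk above with the imports it needs; the sample inputs are illustrative):

import json
from typing import Optional


def convert_logit_bias_input_to_json(user_input) -> Optional[str]:
    if user_input:
        logit_bias_pairs = user_input.split(",")
        logit_bias_json = {}
        for pair in logit_bias_pairs:
            token, bias = pair.split(":")
            token = str(token.strip())
            bias = int(bias.strip())
            bias = 100 if bias > 100 else -100 if bias < -100 else bias
            logit_bias_json[token] = bias
        return json.dumps(logit_bias_json)
    return None


# "token:bias" pairs are parsed and each bias is clamped to [-100, 100].
print(convert_logit_bias_input_to_json("50256:-150, 15496:50"))
# {"50256": -100, "15496": 50}

# Empty or None input now returns None instead of raising, which is why
# both call sites above check the result before calling json.loads().
print(convert_logit_bias_input_to_json(""))    # None
print(convert_logit_bias_input_to_json(None))  # None

Returning Optional[str] rather than raising keeps the helper safe to call with whatever value was stored for the model, and leaves the decision to skip the parameter to each caller.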