fix: reduce and refactor changes

This commit is contained in:
drbh 2024-04-15 19:39:48 +00:00
parent 27cd254b89
commit 24a5588735
2 changed files with 12 additions and 9 deletions

View File

@ -1000,6 +1000,7 @@ async fn chat_completions(
tools, tools,
tool_choice, tool_choice,
tool_prompt, tool_prompt,
temperature
.. ..
} = req; } = req;
@ -1008,6 +1009,10 @@ async fn chat_completions(
let logprobs = logprobs.unwrap_or(false); let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default(); let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default(); let stop = stop.unwrap_or_default();
// rescale where 0 is deterministic and 1 is random (this is the opposite of other endpoints)
let adjusted_temperature = temperature.map_or(1.0, |t| 1.0 - t);
let do_sample = adjusted_temperature > 0.0;
let temperature = Some(adjusted_temperature);
// extract tool grammar if present // extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
@ -1054,13 +1059,13 @@ async fn chat_completions(
inputs: inputs.to_string(), inputs: inputs.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature.map(|t| 1.0 - t), temperature,
repetition_penalty, repetition_penalty,
frequency_penalty: req.frequency_penalty, frequency_penalty: req.frequency_penalty,
top_k: None, top_k: None,
top_p: req.top_p, top_p: req.top_p,
typical_p: None, typical_p: None,
do_sample: req.temperature.map_or(true, |t| t > 0.0), do_sample,
max_new_tokens, max_new_tokens,
return_full_text: None, return_full_text: None,
stop, stop,

View File

@ -273,25 +273,23 @@ class HeterogeneousNextTokenChooser:
else None else None
) )
if any([x != 1.0 for x in temperature]): if any(x != 1.0 for x in temperature):
do_sample = [ do_sample = [
# 1 and 0 both mean no sampling in different contexts sample or x != 1.0 for x, sample in zip(temperature, do_sample)
sample or x == 1.0 or x == 0.0 or math.isclose(x, 0.0)
for x, sample in zip(temperature, do_sample)
] ]
warpers.append( warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
) )
if any([x != 0 for x in top_k]): if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]): if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]): if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))