mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
feat: improve temperature logic in chat (#1749)
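This PR lets the chat endpoint control `do_sample` through `temperature`: a temperature of 0 disables sampling (greedy decoding), while any other value, or no value at all, keeps sampling enabled.
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>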
This commit is contained in:
parent
5d36a5e368
commit
ab59a5e346
@ -1002,6 +1002,7 @@ async fn chat_completions(
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
..
|
||||
} = req;
|
||||
|
||||
@ -1010,6 +1011,11 @@ async fn chat_completions(
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
@ -1056,13 +1062,13 @@ async fn chat_completions(
|
||||
inputs: inputs.to_string(),
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature: req.temperature,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
top_k: None,
|
||||
top_p: req.top_p,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
|
@ -274,7 +274,7 @@ class HeterogeneousNextTokenChooser:
|
||||
else None
|
||||
)
|
||||
|
||||
if any([x != 1.0 for x in temperature]):
|
||||
if any(x != 1.0 for x in temperature):
|
||||
do_sample = [
|
||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||
]
|
||||
@ -282,15 +282,15 @@ class HeterogeneousNextTokenChooser:
|
||||
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
||||
)
|
||||
|
||||
if any([x != 0 for x in top_k]):
|
||||
if any(x != 0 for x in top_k):
|
||||
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
|
||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
||||
|
||||
if any([x < 1.0 for x in top_p]):
|
||||
if any(x < 1.0 for x in top_p):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
|
||||
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
|
||||
|
||||
if any([x < 1.0 for x in typical_p]):
|
||||
if any(x < 1.0 for x in typical_p):
|
||||
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
|
||||
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user