mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: update temperature and sampling logic in chat
This commit is contained in:
parent
0520bde039
commit
27cd254b89
@ -709,10 +709,6 @@ pub(crate) struct ChatRequest {
|
|||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||||
pub tool_choice: Option<ToolType>,
|
pub tool_choice: Option<ToolType>,
|
||||||
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(default = "false", example = true)]
|
|
||||||
pub do_sample: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_tool_prompt() -> Option<String> {
|
fn default_tool_prompt() -> Option<String> {
|
||||||
|
@ -1054,13 +1054,13 @@ async fn chat_completions(
|
|||||||
inputs: inputs.to_string(),
|
inputs: inputs.to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
best_of: None,
|
best_of: None,
|
||||||
temperature: req.temperature,
|
temperature: req.temperature.map(|t| 1.0 - t),
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
frequency_penalty: req.frequency_penalty,
|
frequency_penalty: req.frequency_penalty,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: req.top_p,
|
top_p: req.top_p,
|
||||||
typical_p: None,
|
typical_p: None,
|
||||||
do_sample: req.do_sample,
|
do_sample: req.temperature.map_or(true, |t| t > 0.0),
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop,
|
stop,
|
||||||
|
@ -217,7 +217,7 @@ impl Validation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let temperature = temperature.unwrap_or(1.0);
|
let temperature = temperature.unwrap_or(1.0);
|
||||||
if temperature <= 0.0 {
|
if temperature < 0.0 {
|
||||||
return Err(ValidationError::Temperature);
|
return Err(ValidationError::Temperature);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,7 +275,9 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
if any([x != 1.0 for x in temperature]):
|
if any([x != 1.0 for x in temperature]):
|
||||||
do_sample = [
|
do_sample = [
|
||||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
# 1 and 0 both mean no sampling in different contexts
|
||||||
|
sample or x == 1.0 or x == 0.0 or math.isclose(x, 0.0)
|
||||||
|
for x, sample in zip(temperature, do_sample)
|
||||||
]
|
]
|
||||||
warpers.append(
|
warpers.append(
|
||||||
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
|
||||||
|
Loading…
Reference in New Issue
Block a user