diff --git a/router/src/lib.rs b/router/src/lib.rs
index 17ab00e7..2395e3e2 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -709,10 +709,6 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,
-
-    #[serde(default)]
-    #[schema(default = "false", example = true)]
-    pub do_sample: bool,
 }
 
 fn default_tool_prompt() -> Option<String> {
diff --git a/router/src/server.rs b/router/src/server.rs
index e728455d..68f76260 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1054,13 +1054,13 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: req.temperature,
+            temperature: req.temperature.filter(|&t| t != 0.0),
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
-            do_sample: req.do_sample,
+            do_sample: req.temperature.map_or(true, |t| t != 0.0),
             max_new_tokens,
             return_full_text: None,
             stop,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 2029c7e0..926625c2 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -217,7 +217,7 @@ impl Validation {
         }
 
         let temperature = temperature.unwrap_or(1.0);
-        if temperature <= 0.0 {
+        if temperature < 0.0 {
             return Err(ValidationError::Temperature);
         }
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 7c8a18f0..0a1d8a51 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -275,7 +275,9 @@ class HeterogeneousNextTokenChooser:
         if any([x != 1.0 for x in temperature]):
             do_sample = [
-                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+                # 1.0 (no-op warp) and 0.0 (greedy) both mean no sampling
+                sample or (x != 1.0 and x != 0.0)
+                for x, sample in zip(temperature, do_sample)
             ]
             warpers.append(
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)