From 24a5588735a8d646aee6354ba5fd6a68aa8149c7 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 15 Apr 2024 19:39:48 +0000
Subject: [PATCH] fix: reduce and refactor changes

---
 router/src/server.rs                          |  9 +++++++--
 server/text_generation_server/utils/tokens.py | 12 +++++-------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 68f76260..b7273a18 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1000,6 +1000,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         tool_prompt,
+        temperature,
         ..
     } = req;
 
@@ -1008,6 +1009,10 @@ async fn chat_completions(
     let logprobs = logprobs.unwrap_or(false);
     let tool_prompt = tool_prompt.unwrap_or_default();
     let stop = stop.unwrap_or_default();
+    // rescale where 0 is deterministic and 1 is random (this is the opposite of other endpoints)
+    let adjusted_temperature = temperature.map_or(1.0, |t| 1.0 - t);
+    let do_sample = adjusted_temperature > 0.0;
+    let temperature = Some(adjusted_temperature);
 
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
@@ -1054,13 +1059,13 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: req.temperature.map(|t| 1.0 - t),
+            temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
-            do_sample: req.temperature.map_or(true, |t| t > 0.0),
+            do_sample,
             max_new_tokens,
             return_full_text: None,
             stop,
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0a1d8a51..8ef1ca0d 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -273,25 +273,23 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
-        if any([x != 1.0 for x in temperature]):
+        if any(x != 1.0 for x in temperature):
             do_sample = [
-                # 1 and 0 both mean no sampling in different contexts
-                sample or x == 1.0 or x == 0.0 or math.isclose(x, 0.0)
-                for x, sample in zip(temperature, do_sample)
+                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
             ]
             warpers.append(
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )
 
-        if any([x != 0 for x in top_k]):
+        if any(x != 0 for x in top_k):
             do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
 
-        if any([x < 1.0 for x in top_p]):
+        if any(x < 1.0 for x in top_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
             warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
 
-        if any([x < 1.0 for x in typical_p]):
+        if any(x < 1.0 for x in typical_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
             warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
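
Review note (editor's addition, not part of the patch): the router change rescales the chat endpoint's temperature as 1.0 - t, defaulting to 1.0 when unset, and disables sampling only when the adjusted value is 0. Below is a minimal Python sketch of the same arithmetic as a sanity check; it is written in Python for brevity, and the helper name rescale_chat_temperature is made up for illustration.

    from typing import Optional, Tuple

    def rescale_chat_temperature(temperature: Optional[float]) -> Tuple[float, bool]:
        # Mirrors the Rust logic above: adjusted = 1.0 - t (1.0 if unset);
        # sample only when the adjusted temperature is strictly positive.
        adjusted = 1.0 if temperature is None else 1.0 - temperature
        return adjusted, adjusted > 0.0

    # Observable mapping: None -> (1.0, sampling), 0.0 -> (1.0, sampling),
    # 0.7 -> (~0.3, sampling), 1.0 -> (0.0, greedy decoding).
    for t in (None, 0.0, 0.7, 1.0):
        print(t, rescale_chat_temperature(t))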
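
On the Python side, the refactor removes an inverted condition: the old list comprehension turned sampling on when a temperature was exactly 1.0 or approximately 0.0, contradicting its own comment that those values mean no sampling. The new rule matches the pattern of the other warpers (top_k != 0, top_p < 1.0, typical_p < 1.0): sampling is forced on only for values that actually warp the logits. A standalone sketch of the new temperature rule, with made-up inputs:

    temperature = [1.0, 0.7, 0.0]
    do_sample = [False, False, False]

    # Only temperatures different from the 1.0 default flip do_sample on.
    if any(x != 1.0 for x in temperature):
        do_sample = [
            sample or x != 1.0 for x, sample in zip(temperature, do_sample)
        ]

    print(do_sample)  # [False, True, True]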
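
The any([...]) -> any(...) changes are a small efficiency win: with a generator expression, any() can short-circuit on the first truthy element, whereas the list form materializes every boolean before any() runs. A quick demonstration with synthetic data:

    import timeit

    # The first element already satisfies the predicate, so the generator
    # version stops immediately; the list version still builds 1_000_001
    # booleans before any() sees them.
    xs = [0.5] + [1.0] * 1_000_000
    print(timeit.timeit(lambda: any(x != 1.0 for x in xs), number=100))
    print(timeit.timeit(lambda: any([x != 1.0 for x in xs]), number=100))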