From b45f6484833b20c4efed21c2f8c0c4289cd4f67f Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Mon, 18 Mar 2024 15:17:47 +0100
Subject: [PATCH] Add warmup for logits processors (#107)

Co-authored-by: Karol Damaszke
---
 router/client/src/client.rs                       | 14 +++++++-------
 server/text_generation_server/models/causal_lm.py |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 2bff468c..9aefaf55 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -165,14 +165,14 @@ impl Client {
             inputs: self.get_random_input(input_length, seq_bucket_size),
             truncate: max_input_length,
             parameters: Some(NextTokenChooserParameters {
-                temperature: 1.0,
-                top_k: 0,
-                top_p: 1.0,
-                typical_p: 1.0,
-                do_sample: false,
+                temperature: 0.9,
+                top_k: 10,
+                top_p: 0.9,
+                typical_p: 0.9,
+                do_sample: true,
                 seed: 0,
-                repetition_penalty: 1.0,
-                watermark: false,
+                repetition_penalty: 1.2,
+                watermark: true,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cff7686f..0c5a7288 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -802,11 +802,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids, logits.squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
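
A note on the causal_lm.py hunk: on Gaudi the batch runs with static shapes,
so batch.input_ids carries right-padding, and token_idx_scalar marks where the
real tokens end. Since warmup now enables repetition_penalty, the chooser must
see only real tokens; otherwise the pad token id would be penalized on every
step. The sketch below illustrates why, following the scoring convention of
Hugging Face's RepetitionPenaltyLogitsProcessor. apply_repetition_penalty is a
hypothetical stand-alone helper for illustration, not a function from this
repository.

    import torch

    def apply_repetition_penalty(input_ids: torch.Tensor,
                                 scores: torch.Tensor,
                                 penalty: float) -> torch.Tensor:
        # Gather the current score of every token id seen so far.
        score = torch.gather(scores, 1, input_ids)
        # Divide positive scores and multiply negative ones by the penalty,
        # as in transformers' RepetitionPenaltyLogitsProcessor.
        score = torch.where(score < 0, score * penalty, score / penalty)
        # Write the penalized scores back into the full vocab distribution.
        return scores.scatter(1, input_ids, score)

    # With right-padded ids, the pad token (id 0 here) would be penalized
    # too; slicing to token_idx_scalar keeps the penalty on real tokens only.
    ids = torch.tensor([[5, 7, 0, 0]])  # two real tokens, two pads
    scores = torch.randn(1, 10)
    token_idx_scalar = 2
    penalized = apply_repetition_penalty(ids[:, :token_idx_scalar], scores, 1.2)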