From bf5990ee9ecc5de0a94959fe2523fd12d8d91b6e Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 5 May 2023 17:37:17 +0200
Subject: [PATCH] more explicit

---
 server/text_generation_server/models/flash_causal_lm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 76265217..f4849f36 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -602,7 +602,7 @@ class FlashCausalLM(Model):
             else:
                 # Decode mode
                 # out is of shape [batch_size, vocab_size]
-                logits = out[i].unsqueeze(0)
+                logits = out[i].view(1, -1)
 
             all_input_ids_tensor = batch.all_input_ids_tensor[i]
 
@@ -612,7 +612,7 @@ class FlashCausalLM(Model):
             )
 
             # Add to all_input_ids_tensor
-            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_id_squeezed = next_token_id.view(1)
             all_input_ids_tensor[input_length] = next_token_id_squeezed
 
             # Set values
@@ -630,10 +630,10 @@ class FlashCausalLM(Model):
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(
-                prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
+                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
-            prefill_logprobs = prefill_logprobs.squeeze(1).tolist()
+            prefill_logprobs = prefill_logprobs.view(-1).tolist()
 
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
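
Note: a minimal standalone sketch of why the view() calls are "more explicit" than the squeeze()/unsqueeze() calls they replace. The tensor names and sizes below are illustrative placeholders, not values taken from the model code.

    import torch

    out = torch.randn(4, 32000)           # decode mode: [batch_size, vocab_size] (illustrative sizes)
    i = 0
    # Both produce a [1, vocab_size] row, but view(1, -1) states the intended shape.
    assert torch.equal(out[i].unsqueeze(0), out[i].view(1, -1))

    next_token_id = torch.tensor([[42]])   # assumed [1, 1] output of the token chooser
    # squeeze() silently drops every size-1 dim and yields a 0-d scalar;
    # view(1) pins the result to an explicit 1-d tensor of shape [1].
    assert next_token_id.squeeze().dim() == 0
    assert next_token_id.view(1).shape == (1,)

    logprobs = torch.randn(5, 1)           # gathered prefill logprobs: [n_tokens, 1]
    # squeeze(1) and view(-1) agree here; view(-1) makes the flattening explicit.
    assert torch.equal(logprobs.squeeze(1), logprobs.view(-1))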