diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index ee1bd01f..de9b22da 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -594,7 +594,7 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 60545848..cc9b292f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -657,7 +657,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 736f896f..71182f8d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -520,7 +520,7 @@ class FlashSantacoderModel(nn.Module):
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7be13121..7e048b74 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -404,8 +404,7 @@ class FlashCausalLM(Model):
         # Shortcut when batch_size == 1
         if len(batch) == 1:
             input_ids = batch.input_ids[0].view(-1)
-            # Slice to remove extra padding
-            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
+            # No need to slice as flash attention will take care of it with cu_seqlens
            past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
@@ -454,13 +453,7 @@
         )
         # Set in batch in case it needs to be used later in concatenate()
         batch.past_pad = self.past_pad
-        if len(batch) == 1:
-            # Preallocate tensor for bs = 1 case
-            batch.past_key_values = torch.nn.functional.pad(
-                present,
-                (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
-            )
-        else:
+        if len(batch) != 1:
             # Add padding after each sequence
             # This will have the correct shape after the final past_key_values concatenation before the model
             # forward
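
For context on the pattern these hunks share: the KV cache is preallocated with `max_new_tokens` of padding along the sequence dimension so decode steps can write new entries in place, each modeling file then slices the padded tail off via `slice_past_index` before a layer consumes the cache, and in the batch-of-one flash-attention path `cu_seqlens` already encodes the true sequence length, so the padded cache can be passed through unsliced. Below is a minimal sketch of that pattern, assuming a hypothetical `[num_layers, seq_len, 2, num_heads, head_dim]` cache layout (illustrative only, not necessarily TGI's actual layout):

```python
import torch

# Hypothetical sizes for illustration only.
num_layers, num_heads, head_dim = 2, 4, 8
prompt_len, max_new_tokens = 5, 3

# Preallocate the KV cache with padding for every future token, mirroring
# the torch.nn.functional.pad call in the hunk above: real entries for the
# prompt, zeros for the max_new_tokens that will be generated in place.
past_key_values = torch.zeros(
    num_layers, prompt_len + max_new_tokens, 2, num_heads, head_dim
)

# Suppose one token has been generated: only the first prompt_len + 1
# positions hold real data; the rest is still padding.
slice_past_index = prompt_len + 1

for i in range(num_layers):
    # Same guard as in the modeling files: when slice_past_index is set,
    # drop the padded tail before the layer consumes the cache. In the
    # batch-of-one flash-attention path, slice_past_index is None because
    # cu_seqlens already tells the kernel the true length, so the padded
    # cache is passed through unsliced.
    layer_past_key_values = (
        past_key_values[i]
        if slice_past_index is None
        else past_key_values[i, :slice_past_index]
    )
    assert layer_past_key_values.shape[0] == slice_past_index
```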