diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 2ea88e9d..993e1e2a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -147,7 +147,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = qkv[:, 1:]
+            layer_past[...] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(qkv[:, 0])
@@ -353,9 +353,10 @@ class FlashLlamaModel(torch.nn.Module):
             prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.layers),
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
@@ -389,6 +390,21 @@
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states, past_key_values
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 21362b22..4d42bab6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -132,7 +132,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = qkv[:, 1:]
+            layer_past[...] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(qkv[:, 0])
@@ -362,9 +362,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.layers),
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
@@ -398,6 +399,21 @@
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

         return hidden_states, past_key_values
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index e7665c8d..d9388caf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -158,7 +158,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[...] = kv

             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@@ -295,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[...] = kv
             # Expand to query shape
             kv = (
                 kv.unsqueeze(2)
@@ -629,9 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
             prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.h),
                     *self.cache_size,
                 )
             )
@@ -663,6 +664,19 @@
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.h),
+                    *self.cache_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

         return hidden_states, past_key_values
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 2ccdf045..a15c6050 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -172,7 +172,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = key_value
+            layer_past[...] = key_value

             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@@ -372,8 +372,9 @@ class FlashSantacoderModel(nn.Module):
             prefill = True
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_zeros(
-                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
         # Decode
         else:
@@ -394,6 +395,15 @@
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

         return hidden_states, past_key_values
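
Note: the snippet below is a minimal, self-contained sketch of the prefill cache handling these hunks implement, not code from this PR. Tensor names and sizes (num_tokens, num_layers, num_heads, head_size, pre_allocate_past_size, past_present_indices) are illustrative. During prefill each layer now writes its keys/values with a plain dense copy (layer_past[...] = kv) into a tensor sized to the real tokens, and only after the last layer is that "present" tensor scattered once into the padded, pre-allocated cache via past_present_indices, instead of paying the indexed copy inside every layer.

import torch

# Illustrative sizes; the real models derive these from the config and batch.
num_tokens, num_layers, num_heads, head_size = 5, 2, 4, 8
pre_allocate_past_size = 16  # padded length reserved for later decode steps
past_present_indices = torch.arange(num_tokens)  # slots of the prefill tokens in the padded cache

hidden_states = torch.randn(num_tokens, num_heads * head_size)

# Dense "present" tensor, one slot per real token: each layer fills its slice
# with a contiguous copy, the equivalent of `layer_past[...] = kv`.
present = hidden_states.new_empty((num_tokens, num_layers, 2, num_heads, head_size))
for i in range(num_layers):
    kv = torch.randn(num_tokens, 2, num_heads, head_size)  # stand-in for this layer's K/V
    present[:, i] = kv

# One indexed scatter at the end into the padded, pre-allocated cache,
# instead of `layer_past[past_present_indices] = kv` in every layer.
past_key_values = hidden_states.new_empty(
    (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
)
past_key_values[past_present_indices] = present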