This commit is contained in:
OlivierDehaene 2023-06-12 16:25:23 +02:00
parent 92a74ea036
commit 4b9ebb0a85
4 changed files with 65 additions and 9 deletions

View File

@ -147,7 +147,7 @@ class FlashLlamaAttention(torch.nn.Module):
# Prefill
if prefill:
# Copy to layer past
layer_past[past_present_indices] = qkv[:, 1:]
layer_past[...] = qkv[:, 1:]
# output
attn_output = torch.empty_like(qkv[:, 0])
@ -353,9 +353,10 @@ class FlashLlamaModel(torch.nn.Module):
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(input_ids),
len(self.layers),
2,
self.num_heads,
@ -389,6 +390,21 @@ class FlashLlamaModel(torch.nn.Module):
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values

View File

@ -132,7 +132,7 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill
if prefill:
# Copy to layer past
layer_past[past_present_indices] = qkv[:, 1:]
layer_past[...] = qkv[:, 1:]
# output
attn_output = torch.empty_like(qkv[:, 0])
@ -362,9 +362,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(input_ids),
len(self.layers),
2,
self.num_heads,
@ -398,6 +399,21 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values

View File

@ -158,7 +158,7 @@ class FlashRWAttention(torch.nn.Module):
# Prefill
if prefill:
# Copy to layer past
layer_past[past_present_indices] = kv
layer_past[...] = kv
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@ -295,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill
if prefill:
# Copy to layer past
layer_past[past_present_indices] = kv
layer_past[...] = kv
# Expand to query shape
kv = (
kv.unsqueeze(2)
@ -629,9 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(input_ids),
len(self.h),
*self.cache_size,
)
@ -663,6 +664,19 @@ class FlashRWModel(FlashRWPreTrainedModel):
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values

View File

@ -172,7 +172,7 @@ class FlashMQAttention(torch.nn.Module):
# Prefill
if prefill:
# Copy to layer past
layer_past[past_present_indices] = key_value
layer_past[...] = key_value
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@ -372,8 +372,9 @@ class FlashSantacoderModel(nn.Module):
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
# Decode
else:
@ -394,6 +395,15 @@ class FlashSantacoderModel(nn.Module):
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values