Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 03:44:54 +00:00
faster
This commit is contained in: parent 92a74ea036, commit 4b9ebb0a85
@@ -147,7 +147,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = qkv[:, 1:]
+            layer_past[...] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(qkv[:, 0])
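This hunk, repeated for every model below, is the first half of the speed-up: during prefill, `layer_past` has just been allocated with exactly one slot per prompt token, so `past_present_indices` selects every row and the indexed scatter does redundant work. A full-tensor assignment writes the same values without an index kernel. A minimal sketch of the equivalence (shapes are illustrative, not the repo's values):

```python
import torch

seq_len, num_heads, head_size = 8, 4, 16
qkv = torch.randn(seq_len, 3, num_heads, head_size)

# Old path: scatter through an index tensor that happens to cover every slot.
past_present_indices = torch.arange(seq_len)
layer_past_old = qkv.new_empty((seq_len, 2, num_heads, head_size))
layer_past_old[past_present_indices] = qkv[:, 1:]

# New path: plain full-tensor assignment, no index lookup involved.
layer_past_new = qkv.new_empty((seq_len, 2, num_heads, head_size))
layer_past_new[...] = qkv[:, 1:]

assert torch.equal(layer_past_old, layer_past_new)
```

The second half of the change, which makes this valid, is in the model-level hunks below: the prefill past is now allocated at prompt size rather than at the padded `pre_allocate_past_size`.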
@@ -353,9 +353,10 @@ class FlashLlamaModel(torch.nn.Module):
             prefill = True
 
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.layers),
                     2,
                     self.num_heads,
@@ -389,6 +390,21 @@ class FlashLlamaModel(torch.nn.Module):
                 prefill,
             )
 
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.norm(hidden_states, residual)
 
         return hidden_states, past_key_values
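And this is the second half: instead of pre-allocating the padded past up front and scattering into it at every layer, the model allocates a compact past of prompt length, lets each layer write into it directly, and pays the scatter into the padded decode-time buffer exactly once at the end of the forward pass. The same pattern is applied to NeoX, RW, and Santacoder below. A self-contained sketch with made-up sizes (a contiguous `arange` stands in for `past_present_indices`):

```python
import torch

num_tokens, num_layers, num_heads, head_size = 5, 2, 4, 16
pre_allocate_past_size = 12  # prompt length plus budget for generated tokens
past_present_indices = torch.arange(num_tokens)

# Prefill: compact past, one slot per prompt token; layers fill it
# without any per-layer slicing or index lookups.
present = torch.randn(num_tokens, num_layers, 2, num_heads, head_size)

# After the layer stack: a single scatter into the padded buffer that
# the decode steps will keep appending to.
past_key_values = present.new_empty(
    (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
)
past_key_values[past_present_indices] = present
assert torch.equal(past_key_values[:num_tokens], present)
```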
@@ -132,7 +132,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = qkv[:, 1:]
+            layer_past[...] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(qkv[:, 0])
@@ -362,9 +362,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             prefill = True
 
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.layers),
                     2,
                     self.num_heads,
@@ -398,6 +399,21 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 prefill,
             )
 
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
         return hidden_states, past_key_values
@@ -158,7 +158,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[...] = kv
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
 
@@ -295,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[...] = kv
             # Expand to query shape
             kv = (
                 kv.unsqueeze(2)
@@ -629,9 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
             prefill = True
 
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.h),
                     *self.cache_size,
                 )
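A note on `*self.cache_size`: the FlashRW variants differ in how many K/V heads they cache, so the trailing cache dimensions live in a tuple that is unpacked into the shape. An illustration with an assumed layout (the real tuple depends on the model variant):

```python
import torch

cache_size = (2, 1, 64)  # assumed layout: (key/value, kv heads, head size)
num_tokens, num_layers = 5, 2

# `*` splats the tuple into the shape, so one allocation covers all layouts.
past = torch.empty((num_tokens, num_layers, *cache_size))
print(past.shape)  # torch.Size([5, 2, 2, 1, 64])
```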
@@ -663,6 +664,19 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 prefill,
             )
 
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.h),
+                    *self.cache_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
         return hidden_states, past_key_values
@@ -172,7 +172,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = key_value
+            layer_past[...] = key_value
             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
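Context for the `expand` line kept in this hunk: Santacoder uses multi-query attention, so `layer_past` stores a single shared K/V head and `expand` broadcasts it across all query heads. `expand` returns a stride-0 view rather than copying, so the broadcast costs no memory traffic. A quick demonstration with assumed shapes:

```python
import torch

seq_len, num_heads, head_size = 8, 4, 16
key_value = torch.randn(seq_len, 2, 1, head_size)  # one shared K/V head

expanded = key_value.expand(-1, 2, num_heads, head_size)
print(expanded.shape)    # torch.Size([8, 2, 4, 16])
print(expanded.stride()) # stride 0 on the broadcast head dimension
assert expanded.data_ptr() == key_value.data_ptr()  # a view, not a copy
```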
@@ -372,8 +372,9 @@ class FlashSantacoderModel(nn.Module):
             prefill = True
 
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_zeros(
-                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
         # Decode
         else:
@@ -394,6 +395,15 @@ class FlashSantacoderModel(nn.Module):
                 prefill,
             )
 
+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)
 
         return hidden_states, past_key_values
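One last detail from the Santacoder hunks: the compact prefill past keeps `new_zeros`, while the padded buffer is allocated with `new_empty`, which skips the zero-fill. That is presumably safe because every slot the attention kernel reads is written first, either by prefill or by the decode step that owns it. The difference in brief:

```python
import torch

x = torch.randn(4, 8)
a = x.new_zeros((16, 8))  # allocated and zero-filled (extra memset)
b = x.new_empty((16, 8))  # allocated only; contents arbitrary until written
```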