Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
commit 4b9ebb0a85
parent 92a74ea036

    faster
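
This commit makes the same change in four flash-attention model implementations (FlashLlama, FlashGPTNeoX, FlashRW, FlashSantacoder). During prefill, the key/value cache is now allocated at exactly the input length, so every layer fills it with a plain full-tensor copy (layer_past[...] = ...); the indexed scatter into the padded pre_allocate_past_size buffer is then paid once, after the layer loop, instead of once per layer.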
@@ -147,7 +147,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = qkv[:, 1:]
+            layer_past[...] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(qkv[:, 0])
@@ -353,9 +353,10 @@ class FlashLlamaModel(torch.nn.Module):
             prefill = True

             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.layers),
                     2,
                     self.num_heads,
@@ -389,6 +390,21 @@ class FlashLlamaModel(torch.nn.Module):
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states, past_key_values

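The two FlashLlamaModel hunks above show the whole pattern; the sketch below restates it outside the model code. It is a minimal illustration, not the repository's API: the sizes, the per-layer loop, and the random kv tensors are made up, while past_present_indices and pre_allocate_past_size play the same roles as in the diff.

    import torch

    num_layers, num_heads, head_size = 4, 8, 64
    seq_len, pre_allocate_past_size = 16, 128
    past_present_indices = torch.arange(seq_len)  # slots the prefill tokens occupy

    hidden_states = torch.randn(seq_len, num_heads * head_size)

    # Prefill cache sized to the input, so each layer writes with a plain
    # contiguous copy (stands in for layer_past[...] = qkv[:, 1:]) instead
    # of an indexed scatter.
    past_key_values = hidden_states.new_empty(
        (seq_len, num_layers, 2, num_heads, head_size)
    )
    for layer in range(num_layers):
        kv = torch.randn(seq_len, 2, num_heads, head_size)
        past_key_values[:, layer] = kv

    # After the layer loop, pay the indexed copy exactly once.
    present = past_key_values
    past_key_values = hidden_states.new_empty(
        (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
    )
    past_key_values[past_present_indices] = present

The point is that the per-layer writes become contiguous copies, and the single scatter into the padded buffer is amortized across all layers.
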
@@ -132,7 +132,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = qkv[:, 1:]
+            layer_past[...] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(qkv[:, 0])
@@ -362,9 +362,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             prefill = True

             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.layers),
                     2,
                     self.num_heads,
@@ -398,6 +399,21 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

         return hidden_states, past_key_values

@@ -158,7 +158,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[...] = kv
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)

@@ -295,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[...] = kv
             # Expand to query shape
             kv = (
                 kv.unsqueeze(2)
@@ -629,9 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
             prefill = True

             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
-                    pre_allocate_past_size,
+                    len(input_ids),
                     len(self.h),
                     *self.cache_size,
                 )
@@ -663,6 +664,19 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.h),
+                    *self.cache_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

         return hidden_states, past_key_values

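One difference worth noting: FlashRWModel spells the trailing cache dimensions as *self.cache_size rather than an explicit (2, num_heads, head_size), presumably because its two attention variants (FlashRWAttention and FlashRWLargeAttention above) cache differently shaped key/value tensors; the two-step prefill copy is otherwise identical to the Llama and NeoX versions.
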
@@ -172,7 +172,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = key_value
+            layer_past[...] = key_value
             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

@@ -372,8 +372,9 @@ class FlashSantacoderModel(nn.Module):
             prefill = True

             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_zeros(
-                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
         # Decode
         else:
@@ -394,6 +395,15 @@ class FlashSantacoderModel(nn.Module):
                 prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

         return hidden_states, past_key_values

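As a rough sanity check on why this is faster, the snippet below times an indexed write at every layer against contiguous writes plus one indexed write at the end. It is an assumption-laden micro-benchmark, not taken from the repository: CPU-only, made-up sizes, with num_layers standing in for len(self.layers) / len(self.h).

    import time
    import torch

    num_layers, tokens, padded_size, width = 32, 512, 4096, 2 * 8 * 64
    idx = torch.arange(tokens)
    src = torch.randn(tokens, width)
    dst = torch.empty(padded_size, width)

    # Old behavior: indexed scatter into the padded buffer at every layer.
    start = time.perf_counter()
    for _ in range(num_layers):
        dst[idx] = src
    print(f"indexed every layer: {time.perf_counter() - start:.6f}s")

    # New behavior: contiguous copy per layer, indexed scatter once at the end.
    tight = torch.empty(tokens, width)
    start = time.perf_counter()
    for _ in range(num_layers):
        tight[...] = src
    dst[idx] = tight
    print(f"indexed once:        {time.perf_counter() - start:.6f}s")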