Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
fix(server): cleanup new flash past_key_values logic (#217)

Parent: db4cb5e4ed
Commit: afc5b999d0
@@ -594,7 +594,7 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
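The else branch of this conditional expression is truncated in the hunk above; presumably it slices the preallocated padding off the per-layer cache. Below is a minimal standalone sketch of the pad-then-slice pattern the renamed comment describes. The cache layout, slice_past_index value, and the exact slice are assumptions for illustration, not the repo's exact code:

import torch

# Invented layout: one cache entry per layer, padded along the token dim.
num_layers, padded_len, num_heads, head_dim = 2, 8, 4, 16
past_key_values = torch.zeros(num_layers, padded_len, 2, num_heads, head_dim)

slice_past_index = 5  # tokens actually filled; the rest is preallocated padding

for i in range(num_layers):
    # Mirrors the diff: keep the padded cache as-is when no slice is
    # requested, otherwise drop the padding that was added up front.
    layer_past_key_values = (
        past_key_values[i]
        if slice_past_index is None
        else past_key_values[i, :slice_past_index]
    )
    assert layer_past_key_values.shape[0] == slice_past_index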
@@ -657,7 +657,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
@@ -520,7 +520,7 @@ class FlashSantacoderModel(nn.Module):
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that now need to slice
+            # We added padding that we now need to slice
             layer_past_key_values = (
                 past_key_values[i]
                 if slice_past_index is None
@@ -404,8 +404,7 @@ class FlashCausalLM(Model):
         # Shortcut when batch_size == 1
         if len(batch) == 1:
             input_ids = batch.input_ids[0].view(-1)
-            # Slice to remove extra padding
-            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
+            # No need to slice as flash attention will take care of it with cu_seqlens
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
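The new comment points at cu_seqlens as the reason the slice can be dropped. A minimal sketch of what cu_seqlens is (the lengths below are invented): flash attention kernels take the batch's sequences packed back-to-back with no padding, plus the cumulative sequence lengths marking each boundary, and a sequence only attends within its own span, so padded cache positions past a sequence's true length are never read:

import torch

input_lengths = [5, 3, 7]  # invented per-request token counts

# Cumulative sequence lengths, starting at 0: boundaries of each packed sequence.
cu_seqlens = [0]
for length in input_lengths:
    cu_seqlens.append(cu_seqlens[-1] + length)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
# Sequence i occupies positions [cu_seqlens[i], cu_seqlens[i + 1]) of the
# packed tensor, which is why the commented-out slice could simply be deleted.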
@@ -454,13 +453,7 @@ class FlashCausalLM(Model):
             )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
-            if len(batch) == 1:
-                # Preallocate tensor for bs = 1 case
-                batch.past_key_values = torch.nn.functional.pad(
-                    present,
-                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
-                )
-            else:
+            if len(batch) != 1:
                 # Add padding after each sequence
                 # This will have the correct shape after the final past_key_values concatenation before the model
                 # forward
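For context on the deleted branch: torch.nn.functional.pad consumes its pad tuple from the last dimension backwards, so an 8-value tuple ending in (..., 0, max_new_tokens) grows only the first dimension of a 4-D present tensor, preallocating cache slots for tokens yet to be generated. A quick sketch with made-up shapes (the real layout of present is not shown in this diff):

import torch
import torch.nn.functional as F

present = torch.randn(5, 2, 4, 16)  # assumed [seq_len, 2, num_heads, head_dim]
max_new_tokens = 3

# Pairs apply last-dim-first; only the final pair (0, max_new_tokens) is
# non-zero, so just dim 0 is extended on the right.
padded = F.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
print(padded.shape)  # torch.Size([8, 2, 4, 16])

With this commit, the batch-of-one path keeps its cache unpadded (flash attention handles boundaries via cu_seqlens, per the earlier hunk); only the multi-sequence path still pads, to keep shapes consistent for the final past_key_values concatenation.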