tiny simplification

This commit is contained in:
Nick Hill 2023-04-20 11:27:20 -07:00
parent e360cf92cf
commit 41e0310ef7

View File

@ -407,17 +407,17 @@ class FlashCausalLM(Model):
past_key_values,
)
# Initialize past padding tensor
if self.past_pad is None:
self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
# Initialize past_key_values in prefill
if batch.past_key_values is None:
# Initialize past padding tensor
if self.past_pad is None:
self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
# Set in batch in case it needs to be used later in concatenate()
batch.past_pad = self.past_pad
if len(batch) == 1:
# Preallocate tensor for bs = 1 case
batch.past_key_values = torch.nn.functional.pad(
present, (0, 0, 0, 0, 0, 0, 0, batch.requests[0].stopping_parameters.max_new_tokens)
present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
)
else:
batch.past_key_values = [None, self.past_pad] * len(batch)