Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
tiny simplification
parent e360cf92cf
commit 41e0310ef7
@@ -407,17 +407,17 @@ class FlashCausalLM(Model):
             past_key_values,
         )
 
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
             # Initialize past padding tensor
             if self.past_pad is None:
                 self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.requests[0].stopping_parameters.max_new_tokens)
+                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
                 )
             else:
                 batch.past_key_values = [None, self.past_pad] * len(batch)
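The bs = 1 branch works because torch.nn.functional.pad consumes its pad tuple from the last dimension backwards: the seven leading zeros leave the trailing dimensions untouched, and the final value extends the sequence dimension of `present` by `max_new_tokens`. Below is a minimal, self-contained sketch of that preallocation, using assumed illustrative shapes (a [num_layers, seq_len, 2, num_heads, head_size] layout) rather than the model's real tensors.

import torch

# Illustrative shapes only; the real `present` layout comes from the flash
# attention model and may differ.
num_layers, seq_len, num_heads, head_size = 4, 7, 2, 8
max_new_tokens = 5  # stands in for batch.stopping_criterias[0].max_new_tokens

present = torch.randn(num_layers, seq_len, 2, num_heads, head_size)

# F.pad reads the pad tuple from the last dimension backwards:
# (0, 0) for head_size, (0, 0) for num_heads, (0, 0) for the key/value axis,
# then (0, max_new_tokens) for the sequence dimension (dim 1).
past_key_values = torch.nn.functional.pad(
    present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens)
)

assert past_key_values.shape == (
    num_layers,
    seq_len + max_new_tokens,
    2,
    num_heads,
    head_size,
)

The extra zero rows give subsequent decode steps room to write new key/value entries in place, instead of concatenating a fresh tensor on every generated token.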