Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
tiny simplification
commit 41e0310ef7
parent e360cf92cf
@@ -407,17 +407,17 @@ class FlashCausalLM(Model):
             past_key_values,
         )

-        # Initialize past padding tensor
-        if self.past_pad is None:
-            self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
-
-        # Initialize past_key_values in prefill
-        if batch.past_key_values is None:
+        # Initialize past_key_values in prefill
+        if batch.past_key_values is None:
+            # Initialize past padding tensor
+            if self.past_pad is None:
+                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+            # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.requests[0].stopping_parameters.max_new_tokens)
+                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
                 )
             else:
                 batch.past_key_values = [None, self.past_pad] * len(batch)
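For reference, below is a minimal, self-contained sketch (not part of the commit; the tensor shapes and sizes are hypothetical placeholders) of what the preallocation in the len(batch) == 1 branch does. torch.nn.functional.pad consumes its pad tuple from the last dimension backwards, two values per dimension, so the trailing (0, max_new_tokens) pair grows the token dimension of `present` with zeros, reserving one key/value slot per future decode step; `past_pad` is the single-token zero slice kept around for padding.

    import torch

    # Hypothetical layout: (num_layers, tokens, 2 [key/value], num_heads, head_size)
    num_layers, seq_len, num_heads, head_size = 2, 5, 4, 8
    max_new_tokens = 3
    present = torch.randn(num_layers, seq_len, 2, num_heads, head_size)

    # The 8-element pad tuple covers the last four dimensions; only the
    # fourth-from-last dimension (the token dimension here) is padded, at the end.
    past_key_values = torch.nn.functional.pad(
        present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens)
    )
    assert past_key_values.shape == (num_layers, seq_len + max_new_tokens, 2, num_heads, head_size)

    # Single-token zero slice analogous to `self.past_pad` above.
    past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
    assert past_pad.shape == (num_layers, 1, 2, num_heads, head_size)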