fix(server): fix flash causal

This commit is contained in:
OlivierDehaene 2023-04-21 19:41:52 +02:00
parent afc5b999d0
commit 91c87b8013

View File

@ -453,7 +453,10 @@ class FlashCausalLM(Model):
)
# Set in batch in case it needs to be used later in concatenate()
batch.past_pad = self.past_pad
if len(batch) != 1:
if len(batch) == 1:
# present is already pre-padded
batch.past_key_values = present
else:
# Add padding after each sequence
# This will have the correct shape after the final past_key_values concatenation before the model
# forward