From 91c87b80136159714817d4ae1b16776b9a085294 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 21 Apr 2023 19:41:52 +0200 Subject: [PATCH] fix(server): fix flash causal --- server/text_generation_server/models/flash_causal_lm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7e048b74..c44dd57d 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -453,7 +453,10 @@ class FlashCausalLM(Model): ) # Set in batch in case it needs to be used later in concatenate() batch.past_pad = self.past_pad - if len(batch) != 1: + if len(batch) == 1: + # present is already pre-padded + batch.past_key_values = present + else: # Add padding after each sequence # This will have the correct shape after the final past_key_values concatenation before the model # forward