From 27b3a144f77dffd380910031d62004daed93437b Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 21 Apr 2023 20:25:13 +0200
Subject: [PATCH] fix(server): fix flash batch filtering

---
 server/text_generation_server/models/flash_causal_lm.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f332ab51..9cd9ed89 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
+            # True index for past
+            past_key_values.append(self.past_key_values[2 * idx])
+
             if not single_request:
-                # True index for past
-                past_key_values.append(self.past_key_values[2 * idx])
                 # Add one padding
                 past_key_values.append(self.past_pad)
 
@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
         if single_request:
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
-                self.past_key_values[0],
+                past_key_values[0],
                 (
                     0,
                     0,
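
The bug: when filtering reduced the batch to a single request, the append of
that request's past sat inside the `if not single_request` branch and was
skipped, and the bs = 1 preallocation then padded `self.past_key_values[0]`,
i.e. the past of the first request in the old, unfiltered batch rather than
the request that survived. Below is a minimal Python sketch of that logic,
not the TGI code itself: `old_past`, `keep_indices`, and `PAST_PAD` are
hypothetical names, and the pad amounts are illustrative since the real
values fall outside the hunk's context; only the `2 * idx` stride and the
switch to `past_key_values[0]` come from the patch.

    import torch

    PAST_PAD = torch.zeros(1, 2, 4)  # hypothetical padding slot

    def filter_past(old_past, keep_indices):
        # Sketch of the filtered batch's past_key_values handling.
        single_request = len(keep_indices) == 1
        past_key_values = []
        for idx in keep_indices:
            # True index for past: pasts are interleaved with padding
            # slots, so request `idx` lives at slot 2 * idx. Before the
            # fix this append sat inside the `if not single_request`
            # branch, so it was skipped exactly when the batch shrank to
            # one request.
            past_key_values.append(old_past[2 * idx])
            if not single_request:
                # Add one padding
                past_key_values.append(PAST_PAD)
        if single_request:
            # Preallocate tensor for bs = 1 case. The buggy version
            # padded the old batch's first past (old_past[0]) instead of
            # the past just collected for the surviving request. Pad
            # amounts here are illustrative only.
            return torch.nn.functional.pad(
                past_key_values[0], (0, 0, 0, 0, 0, 1)
            )
        return past_key_values

For example, with keep_indices == [1] the buggy version would return a padded
copy of request 0's past, while the fixed version pads the past of the one
request that actually survived the filter.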