From 41e0310ef7567c2e59afab7235265cd2e908e891 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 20 Apr 2023 11:27:20 -0700
Subject: [PATCH] tiny simplification

---
 .../text_generation_server/models/flash_causal_lm.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c99fd629..2843f273 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -407,17 +407,17 @@ class FlashCausalLM(Model):
             past_key_values,
         )
-        # Initialize past padding tensor
-        if self.past_pad is None:
-            self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
-
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
+            # Initialize past padding tensor
+            if self.past_pad is None:
+                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.requests[0].stopping_parameters.max_new_tokens)
+                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
                 )
             else:
                 batch.past_key_values = [None, self.past_pad] * len(batch)
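
Note (not part of the patch): the preallocation line above leans on how torch.nn.functional.pad interprets its pad tuple. Below is a minimal standalone sketch of that behavior, assuming `present` is a 5-D KV-cache tensor whose sequence dimension is dim 1, as implied by `present.new_zeros(present.shape[0], 1, *present.shape[2:])` in the hunk; the concrete sizes are invented for illustration and are not taken from the patch.

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes, for illustration only.
    num_layers, seq_len, num_heads, head_size = 2, 3, 4, 8
    max_new_tokens = 5

    # Stand-in for `present`: (num_layers, seq_len, 2, num_heads, head_size).
    present = torch.randn(num_layers, seq_len, 2, num_heads, head_size)

    # F.pad reads the pad tuple in (left, right) pairs starting from the LAST
    # dimension, so an 8-value tuple on a 5-D tensor leaves dims 4, 3, and 2
    # untouched and appends max_new_tokens zero slots to the right of dim 1.
    preallocated = F.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
    assert preallocated.shape == (num_layers, seq_len + max_new_tokens, 2, num_heads, head_size)

    # Equivalent to concatenating zero slots along the sequence dimension,
    # but done in a single allocation.
    zeros = present.new_zeros(num_layers, max_new_tokens, 2, num_heads, head_size)
    assert torch.equal(preallocated, torch.cat([present, zeros], dim=1))

This is why the single-sequence (bs = 1) branch can grow the cache once up front by `max_new_tokens` rather than concatenating a padding slice on every decode step.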