diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cc89049e..1db5abce 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -179,7 +179,7 @@ class CausalLMBatch(Batch): position_ids = self.position_ids[keep_indices] self.attention_mask = self.attention_mask[ keep_indices, - -(self.padding_right_offset + self.max_input_length): + -(self.padding_right_offset + max_input_length): (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset, ] @@ -188,7 +188,7 @@ class CausalLMBatch(Batch): self.past_key_values = [list(layer) for layer in self.past_key_values] # Update tensors in-place to allow incremental garbage collection - past_kv_length = self.max_input_length - 1 + past_kv_length = max_input_length - 1 for layer in self.past_key_values: past_keys, past_values = layer if len(past_keys.shape) == 3: