From ab20142c146ae8bfd51146c0d5226e91911c6cae Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 24 Apr 2023 07:24:35 +0100
Subject: [PATCH] trim to new max input length in filter()

---
 server/text_generation_server/models/causal_lm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cc89049e..1db5abce 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -179,7 +179,7 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + self.max_input_length):
+            -(self.padding_right_offset + max_input_length):
             (self.attention_mask.shape[1] - self.padding_right_offset)
             + new_padding_right_offset,
         ]
@@ -188,7 +188,7 @@ class CausalLMBatch(Batch):
         self.past_key_values = [list(layer) for layer in self.past_key_values]

         # Update tensors in-place to allow incremental garbage collection
-        past_kv_length = self.max_input_length - 1
+        past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
             if len(past_keys.shape) == 3: