trim to new max input length in filter()

Nick Hill 2023-04-24 07:24:35 +01:00
parent 0b1d0010a4
commit ab20142c14


@@ -179,7 +179,7 @@ class CausalLMBatch(Batch):
         position_ids = self.position_ids[keep_indices]
         self.attention_mask = self.attention_mask[
             keep_indices,
-            -(self.padding_right_offset + self.max_input_length):
+            -(self.padding_right_offset + max_input_length):
             (self.attention_mask.shape[1] - self.padding_right_offset) + new_padding_right_offset,
         ]
@@ -188,7 +188,7 @@ class CausalLMBatch(Batch):
             self.past_key_values = [list(layer) for layer in self.past_key_values]
         # Update tensors in-place to allow incremental garbage collection
-        past_kv_length = self.max_input_length - 1
+        past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
             if len(past_keys.shape) == 3:
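
For context, a minimal standalone sketch of what the change accomplishes, using assumed shapes and hypothetical variable names rather than the project's actual classes: after filter() drops finished requests, the remaining prompts may all be shorter than the old batch maximum, so slicing the left-padded attention mask by the new, smaller max_input_length trims away columns that are now pure padding.

```python
import torch

# Hypothetical left-padded attention mask of shape
# [batch, old_max_input_length + padding_right_offset].
batch, old_max_input_length, right_offset = 4, 10, 2
attention_mask = torch.zeros(batch, old_max_input_length + right_offset, dtype=torch.int64)

# Suppose only two requests survive the filter and their longest prompt
# is 6 tokens: the batch maximum shrinks from 10 to 6.
keep_indices = torch.tensor([0, 2])
new_max_input_length = 6
new_right_offset = 2

# Same slicing pattern as the diff: index the kept rows and keep only the
# last (new_max_input_length + right_offset) columns, instead of slicing
# by the old (larger) maximum.
trimmed = attention_mask[
    keep_indices,
    -(right_offset + new_max_input_length):
    (attention_mask.shape[1] - right_offset) + new_right_offset,
]
print(trimmed.shape)  # torch.Size([2, 8]) instead of [2, 12] with the old maximum
```

The second hunk applies the same idea to the cached key/value tensors by computing past_kv_length from the new maximum, so the trimmed cache matches the trimmed attention mask.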