Fix mistral with length > window_size for long prefills (rotary doesn't

create long enough cos, sin).
This commit is contained in:
Nicolas Patry 2024-02-16 18:08:02 +00:00
parent 4139054b82
commit 2804a74276

View File

@ -460,8 +460,8 @@ class BaseFlashMistral(FlashCausalLM):
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if self.model.max_past is not None:
max_s = min(self.model.max_past, max_s)
# if self.model.max_past is not None:
# max_s = min(self.model.max_past, max_s)
bs = input_ids.shape[0]
padded_bs = bs