Let it work.

This commit is contained in:
Nicolas Patry 2024-02-19 10:06:56 +00:00
parent 2804a74276
commit b189342170

View File

@@ -460,8 +460,11 @@ class BaseFlashMistral(FlashCausalLM):
 max_s = batch.max_seqlen
 lm_head_indices = batch.prefill_head_indices
-# if self.model.max_past is not None:
-# max_s = min(self.model.max_past, max_s)
+if cu_seqlen_prefill is None and self.model.max_past is not None:
+    # In decode, not prefill, we're actually overwriting the KV-cache
+    # in a circular buffer mode.
+    # This makes sure the max_s for the decode pass is correct.
+    max_s = min(self.model.max_past, max_s)
 bs = input_ids.shape[0]
 padded_bs = bs