Fixed flashinfer version.

This commit is contained in:
Nicolas Patry 2024-08-27 15:00:22 +02:00
parent bb9769ed42
commit 55d984d730
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@@ -1469,8 +1469,6 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
-# TODO
-# input_lengths = input_lengths + prefix_lens_tensor
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
@@ -1481,7 +1479,7 @@ class FlashCausalLM(Model):
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths,
-input_lengths_tensor=input_lengths,
+input_lengths_tensor=input_lengths + prefix_lens_tensor,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
):