diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 265255dd1..22e0ada18 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1469,8 +1469,6 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - # TODO - # input_lengths = input_lengths + prefix_lens_tensor if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, @@ -1481,7 +1479,7 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=batch.input_lengths, - input_lengths_tensor=input_lengths, + input_lengths_tensor=input_lengths + prefix_lens_tensor, prefix_lens=batch.prefix_lens, prefix_lens_tensor=prefix_lens_tensor, ):