diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 9479b6067..7095593b9 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -84,16 +84,16 @@ def use_prefill_with_paged_kv_state( token = prefill_with_paged_kv_state.set(state) try: state.plan( - qo_indptr=cu_seqlens, - paged_kv_indptr=indptr, - paged_kv_indices=block_tables, - paged_kv_last_page_len=last_page_len, - num_qo_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_size, - kv_data_type=kv_dtype, + cu_seqlens, + indptr, + block_tables, + last_page_len, + num_heads, + num_kv_heads, + head_size, + page_size, q_data_type=q_dtype, - page_size=page_size, + kv_data_type=kv_dtype, ) yield finally: