Fixing the logic.

This commit is contained in:
Nicolas Patry 2024-09-24 14:42:01 +02:00
parent 56c630a425
commit d77a31cd95
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@ -152,11 +152,13 @@ def create_decode_state(
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 1,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 1,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)