mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fix the `use_tensor_cores` logic for the flashinfer decode states.
This commit is contained in:
parent
56c630a425
commit
d77a31cd95
@ -152,11 +152,13 @@ def create_decode_state(
|
||||
):
|
||||
"""Create a decode state."""
|
||||
workspace_buffer = get_workspace(device)
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout="NHD",
|
||||
use_cuda_graph=False,
|
||||
use_tensor_cores=num_heads // num_kv_heads > 1,
|
||||
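# Enable tensor cores whenever the query/KV head group size is not 1, 2, 4 or 8 (presumably the group sizes flashinfer's decode kernels are compiled for; confirm against the linked source). Condition taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60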
|
||||
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||
)
|
||||
|
||||
|
||||
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
|
||||
therefore stored as part of the state.
|
||||
"""
|
||||
workspace_buffer = get_workspace(device)
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout="NHD",
|
||||
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
|
||||
paged_kv_indices_buffer=block_tables,
|
||||
paged_kv_indptr_buffer=block_tables_ptr,
|
||||
paged_kv_last_page_len_buffer=last_page_len,
|
||||
use_tensor_cores=num_heads // num_kv_heads > 1,
|
||||
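# Enable tensor cores whenever the query/KV head group size is not 1, 2, 4 or 8 (presumably the group sizes flashinfer's decode kernels are compiled for; confirm against the linked source). Condition taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60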
|
||||
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user