diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index e1ef62c5..dbb7ee96 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -156,7 +156,7 @@ def create_decode_state( workspace_buffer, kv_layout="NHD", use_cuda_graph=False, - use_tensor_cores=num_heads // num_kv_heads > 4, + use_tensor_cores=num_heads // num_kv_heads > 1, ) @@ -182,7 +182,7 @@ def create_decode_state_cuda_graphs( paged_kv_indices_buffer=block_tables, paged_kv_indptr_buffer=block_tables_ptr, paged_kv_last_page_len_buffer=last_page_len, - use_tensor_cores=num_heads // num_kv_heads > 4, + use_tensor_cores=num_heads // num_kv_heads > 1, )