commit 212a59544b
parent fcbc6876c0
Nicolas Patry, 2024-06-25 13:10:20 +00:00


@@ -1,6 +1,6 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -62,7 +62,8 @@ def paged_attention(
     #
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = value_cache.shape[3]
+    # block_size = value_cache.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
 
     input_lengths = cu_seqlen_k
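
This hunk stops probing the tensor for the block size: once FLASH_DECODING is active, the value cache is presumably no longer laid out as [num_blocks, num_heads, head_size, block_size], so value_cache.shape[3] stops being the block size, and the code reads the BLOCK_SIZE constant from globals instead. A minimal sketch of what that global and the layout difference might look like (the 256/16 values, the env-var toggle, the flash-decoding layout, and block_size_from_cache are illustrative assumptions, not taken from this commit):

import os

# Sketch of an assumed globals module: flash decoding toggled by an env
# var, with a larger block size when it is active (values illustrative).
FLASH_DECODING = os.getenv("FLASH_DECODING", "0") == "1"
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16

def block_size_from_cache(value_cache, flash_decoding: bool) -> int:
    # assumed paged layout:          [num_blocks, num_heads, head_size, block_size]
    # assumed flash-decoding layout: [num_blocks, block_size, num_heads, head_size]
    # shape[3] would return head_size under the second layout, which is
    # why reading a module-level constant is safer than probing the tensor.
    return value_cache.shape[1] if flash_decoding else value_cache.shape[3]
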
@@ -87,7 +88,7 @@ def paged_attention(
             key_cache,
             value_cache,
             None,
-            cu_seqlen_k,
+            cu_seqlen_q,
             cu_seqlen_k,
             None,
             block_tables,
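
The last hunk fixes an argument-order bug: the varlen attention call was passing cu_seqlen_k in the slot that expects cumulative query lengths. A short sketch of how such cumulative-length tensors are built and why the two differ at decode time (make_cu_seqlens and the concrete lengths are illustrative, not from this commit):

import torch

def make_cu_seqlens(lengths: list[int]) -> torch.Tensor:
    # prefix sums with a leading zero: [3, 5] -> [0, 3, 8]
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return torch.tensor(cu, dtype=torch.int32)

cu_seqlen_q = make_cu_seqlens([1, 1, 1])   # decode step: one new token per sequence
cu_seqlen_k = make_cu_seqlens([7, 12, 5])  # full cached context per sequence
# cu_seqlen_q -> tensor([0, 1, 2, 3]); cu_seqlen_k -> tensor([0, 7, 19, 24])
# Passing cu_seqlen_k in the query slot makes the kernel index the packed
# query tensor with key offsets, reading the wrong rows (or out of bounds).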