From 212a59544b348a69e6692e1e2dc0dab37fa3d645 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 25 Jun 2024 13:10:20 +0000
Subject: [PATCH] Update?

---
 server/text_generation_server/layers/attention/cuda.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index d87d9199..e0f09847 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -1,6 +1,6 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -62,7 +62,8 @@ def paged_attention(
     #
 
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = value_cache.shape[3]
+    # block_size = value_cache.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
     input_lengths = cu_seqlen_k
@@ -87,7 +88,7 @@ def paged_attention(
         key_cache,
         value_cache,
         None,
-        cu_seqlen_k,
+        cu_seqlen_q,
         cu_seqlen_k,
         None,
         block_tables,
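
Note (not part of the patch): the switch from value_cache.shape[3] to the BLOCK_SIZE global, which the first hunk imports from models/globals.py, presumably accounts for the flash-decoding cache layout, where the block dimension is no longer dim 3 of the value cache, so reading it from the tensor shape would silently pick up the wrong value. The last hunk simply passes cu_seqlen_q where the query cumulative sequence lengths are expected, instead of repeating cu_seqlen_k. A minimal sketch of the layout assumption with made-up shapes; the flash-decoding layout shown here is an inference, not something stated in the diff:

import torch

# Hypothetical shapes, for illustration only.
num_blocks, num_heads, head_size, block_size = 8, 4, 64, 16

# Classic paged-attention layout: the block dimension really is dim 3,
# so value_cache.shape[3] and a BLOCK_SIZE constant agree.
paged_value_cache = torch.empty(num_blocks, num_heads, head_size, block_size)
assert paged_value_cache.shape[3] == block_size

# Assumed flash-decoding layout (block dimension next to num_blocks):
# here shape[3] returns head_size, so reading the global constant,
# as the patch does, stays correct for both layouts.
flash_value_cache = torch.empty(num_blocks, block_size, num_heads, head_size)
assert flash_value_cache.shape[3] == head_size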