diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 36db12d0..99c490d5 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -2,6 +2,7 @@ import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.layers.attention import Seqlen
 from loguru import logger
 
 major, minor = torch.cuda.get_device_capability()
@@ -45,8 +46,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -70,7 +70,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
+    input_lengths = input_lengths.input_lengths
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
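
For context, here is a minimal sketch of the `Seqlen` wrapper this signature change relies on. Only the `input_lengths` attribute is confirmed by the diff (the kernel reads it via `input_lengths.input_lengths`); the dataclass form and the optional cumulative-length fields are assumptions for illustration, not the library's actual definition.

```python
# Sketch only: the real Seqlen lives in text_generation_server.layers.attention.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Seqlen:
    # Per-sequence key/value lengths, one entry per sequence in the batch.
    # This is the only field the ROCm paged_attention path reads in this diff.
    input_lengths: torch.Tensor
    # Hypothetical cumulative-length tensors, kept for backends that need them.
    cu_seqlen_q: Optional[torch.Tensor] = None
    cu_seqlen_k: Optional[torch.Tensor] = None


# A caller would then pass one wrapper instead of two raw tensors, e.g.:
# paged_attention(out, query, key_cache, value_cache, kv_head_mapping,
#                 softmax_scale, block_tables,
#                 Seqlen(input_lengths=lengths), max_s)
```

The apparent intent is to bundle the sequence-length metadata into a single object so the `paged_attention` signature stays the same across attention backends, with each backend pulling out only the fields it needs.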