diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 0d3cccd69..ea5c45582 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -1,5 +1,6 @@ import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING _PARTITION_SIZE = 512 @@ -125,7 +126,9 @@ def attention( else: from vllm._C import ops - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = max_s <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 + ) if use_v1: ops.paged_attention_v1( out,