diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 4b12744c..af9851dc 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -1,8 +1,8 @@
 import torch
 
 # vllm imports
-from vllm import cache_ops
-from vllm import attention_ops
+from vllm._C import cache_ops
+from vllm._C import ops
 
 _PARTITION_SIZE = 512
 
@@ -56,7 +56,7 @@ def attention(
     # to parallelize.
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
     if use_v1:
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
             out,
             query,
             key_cache,
@@ -83,7 +83,7 @@ def attention(
             device=out.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
             out,
             exp_sums,
             max_logits,
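
For context: newer vllm releases moved the custom kernels (cache and paged-attention ops) under the compiled `vllm._C` extension, which is why the old top-level `from vllm import cache_ops` / `from vllm import attention_ops` imports break. This diff simply targets the new layout and renames the call sites. A hypothetical alternative, not part of this diff, would be an import shim that tolerates both vllm layouts instead of hard-pinning the new one:

```python
# Hypothetical compatibility shim (not in this diff): prefer the compiled
# vllm._C extension, but fall back to the old top-level modules so the
# server keeps working against older vllm versions.
try:
    # Newer vllm: kernels live in the compiled _C extension.
    from vllm._C import cache_ops
    from vllm._C import ops as attention_ops
except ImportError:
    # Older vllm: kernels were exposed as top-level modules.
    from vllm import cache_ops
    from vllm import attention_ops
```

With a shim like this, the call sites could keep the `attention_ops.paged_attention_v1` / `attention_ops.paged_attention_v2` names unchanged; the diff instead takes the simpler route of renaming them to `ops.paged_attention_v1` / `ops.paged_attention_v2` and requiring the new vllm layout.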