fix vllm import error

Zeyu Li 2023-12-30 14:26:37 +08:00
parent 630800eed3
commit ad7f839673


@@ -1,8 +1,8 @@
 import torch

 # vllm imports
-from vllm import cache_ops
-from vllm import attention_ops
+from vllm._C import cache_ops
+from vllm._C import ops

 _PARTITION_SIZE = 512
@@ -56,7 +56,7 @@ def attention(
     # to parallelize.
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
     if use_v1:
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
             out,
             query,
             key_cache,
@@ -83,7 +83,7 @@ def attention(
             device=out.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
             out,
             exp_sums,
             max_logits,
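
For other call sites hitting the same import error, a minimal compatibility sketch (not part of this commit; the try/except fallback and the alias are illustrative assumptions) that works whether the CUDA kernels live under vllm._C, as in newer vLLM builds, or as the older top-level modules:

# Illustrative compatibility shim, not from this commit: newer vLLM builds
# expose the CUDA kernels under vllm._C, older ones as top-level modules.
try:
    from vllm._C import cache_ops, ops
except ImportError:  # fall back to the older vLLM module layout
    from vllm import cache_ops
    from vllm import attention_ops as ops

# Either way, ops.paged_attention_v1 / ops.paged_attention_v2 and
# cache_ops can then be called uniformly, as in the diff above.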