mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix vllm import error
This commit is contained in:
parent
630800eed3
commit
ad7f839673
@ -1,8 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
# vllm imports
|
# vllm imports
|
||||||
from vllm import cache_ops
|
from vllm._C import cache_ops
|
||||||
from vllm import attention_ops
|
from vllm._C import ops
|
||||||
|
|
||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ def attention(
|
|||||||
# to parallelize.
|
# to parallelize.
|
||||||
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
|
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||||
if use_v1:
|
if use_v1:
|
||||||
attention_ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
@ -83,7 +83,7 @@ def attention(
|
|||||||
device=out.device,
|
device=out.device,
|
||||||
)
|
)
|
||||||
max_logits = torch.empty_like(exp_sums)
|
max_logits = torch.empty_like(exp_sums)
|
||||||
attention_ops.paged_attention_v2(
|
ops.paged_attention_v2(
|
||||||
out,
|
out,
|
||||||
exp_sums,
|
exp_sums,
|
||||||
max_logits,
|
max_logits,
|
||||||
|
Loading…
Reference in New Issue
Block a user