Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
use modified vllm rope kernel when using rocm
This commit is contained in:
parent 52c0e0e53b
commit 736e199737
@@ -4,8 +4,8 @@ build-vllm-cuda: BRANCH=main
 build-vllm-cuda: build-vllm
 
 build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=65f4a79621b4d992cf97f6b84598804eb4ca87b6
-build-vllm-rocm: BRANCH=port-to-rocm
+build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
 build-vllm-rocm: build-vllm
 
 vllm:
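The pinned fork commit and branch only matter for ROCm builds; at runtime the Python code shown later in this diff dispatches between the flash-attn and vllm kernels via IS_CUDA_SYSTEM / IS_ROCM_SYSTEM. A minimal sketch of how such flags are typically derived from the PyTorch build, assuming ROCm wheels populate torch.version.hip (how the repository actually defines these flags is not shown in this diff):

    import torch

    # Hypothetical sketch: ROCm builds of PyTorch expose torch.version.hip,
    # CUDA builds expose torch.version.cuda. The names mirror the flags
    # used further down in this diff.
    IS_CUDA_SYSTEM = torch.version.cuda is not None
    IS_ROCM_SYSTEM = torch.version.hip is not None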
@@ -223,9 +223,6 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        # self.rotary_emb = PositionRotaryEmbedding.load(
-        #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
-        # )
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
@@ -280,9 +277,8 @@ class FlashLlamaAttention(torch.nn.Module):
         )
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
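The two in-place rotary calls are merged into a single call that receives both query and key; the same pattern repeats in the Mistral, NeoX and RW hunks below. A short sketch (illustrative shapes, not taken from the diff) of why passing torch.select(kv, dim=1, index=0) is enough: select returns a view into kv, so an in-place rotation of that view also updates the keys that are cached right after.

    import torch

    # Illustrative shapes: (num_tokens, num_heads / num_kv_heads, head_size).
    num_tokens, num_heads, num_kv_heads, head_size = 16, 32, 8, 128
    query = torch.randn(num_tokens, num_heads, head_size)
    kv = torch.randn(num_tokens, 2, num_kv_heads, head_size)

    # torch.select returns a (non-contiguous) view, not a copy ...
    key_view = torch.select(kv, dim=1, index=0)
    assert key_view.shape == (num_tokens, num_kv_heads, head_size)
    assert key_view.data_ptr() == kv.data_ptr()

    # ... so an in-place update through the view is visible in kv[:, 0].
    key_view.zero_()  # stand-in for the in-place rotary kernel
    assert torch.all(kv[:, 0] == 0)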
@@ -280,8 +280,7 @@ class MistralAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
@@ -135,8 +135,7 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
 
         paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
@@ -185,8 +185,7 @@ class FlashRWAttention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
@@ -301,8 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
@@ -425,8 +425,6 @@ class IdeficsRMSNorm(nn.Module):
                 self.weight.data,
                 self.variance_epsilon,
             )
-            if res is None:
-                res = hidden_states
 
             if unwrap:
                 out = out.view(*shape)
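For reference, a minimal plain-PyTorch sketch of the RMS normalization that IdeficsRMSNorm computes. The fused kernel invoked above also adds the residual `res` in the same launch; this sketch covers only the normalization, and the names and default epsilon are illustrative, not taken from the diff.

    import torch

    def rms_norm_ref(hidden_states, weight, variance_epsilon=1e-6):
        # Scale by the inverse root-mean-square of the activations,
        # then apply the learned per-channel weight.
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states.float() * torch.rsqrt(variance + variance_epsilon)
        return weight * normed.to(weight.dtype)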
@@ -613,15 +611,12 @@ class IdeficsAttention(nn.Module):
             position_ids.view(-1), max_s, hidden_states.dtype
         )
 
-        shape = query_states.shape
-        query_states = self.rotary_emb(
-            query_states.view(-1, *shape[2:]), cos, sin
-        ).view(shape)
-
-        shape = key_states.shape
-        key_states = self.rotary_emb(
-            key_states.reshape(-1, *shape[2:]), cos, sin
-        ).view(shape)
+        query_shape = query_states.shape
+        key_shape = key_states.shape
+        self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
+
+        query_states = query_states.view(query_shape)
+        key_states = key_states.view(key_shape)
 
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
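The Idefics path flattens (batch, seq_len, heads, head_size) activations into the (tokens, heads, head_size) layout the rotary kernels expect, then restores the original shape. A short sketch with hypothetical shapes of why the flattened view is enough for an in-place kernel:

    import torch

    # Hypothetical shapes: (batch, seq_len, num_heads, head_size).
    batch, seq_len, num_heads, head_size = 2, 5, 4, 64
    query_states = torch.randn(batch, seq_len, num_heads, head_size)

    query_shape = query_states.shape
    flat = query_states.view(-1, *query_shape[2:])  # (batch * seq_len, num_heads, head_size)

    flat.zero_()  # stand-in for the in-place rotary kernel

    # The flattened tensor is a view, so the 4-D tensor sees the update
    # and only needs its original shape back, not a copy.
    assert torch.all(query_states == 0)
    query_states = query_states.view(query_shape)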
@@ -555,6 +555,8 @@ try:
     if IS_CUDA_SYSTEM:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
+    elif IS_ROCM_SYSTEM:
+        from vllm import pos_encoding_ops
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -583,32 +585,34 @@ try:
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
 
-        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             # Such controlflows may add some overhead.
             if IS_CUDA_SYSTEM:
                 rotary_dim = cos.shape[-1]
-                x1 = x[..., :rotary_dim]
-                x2 = x[..., rotary_dim : 2 * rotary_dim]
+                q1 = query[..., :rotary_dim]
+                q2 = query[..., rotary_dim : 2 * rotary_dim]
 
-                rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-                return x
+                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+                k1 = key[..., :rotary_dim]
+                k2 = key[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
             elif IS_ROCM_SYSTEM:
-                # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
-                # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
-                rotary_dim = cos.shape[-1]
-
-                dtype = x.dtype
-                x_upcast = x.to(torch.float32)
-                cos = cos.to(torch.float32)
-                sin = sin.to(torch.float32)
-
-                x1 = x_upcast[..., :rotary_dim]
-                x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
-
-                # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
-                x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
-                x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
-                return x
+                # NOTE: On RoCm systems, we use a ROPE implementation adapted from VLLM, which launches a single kernel for both query and key, contrary to the flash-attn implementation used on NVIDIA systems.
+                # Compiling the flash-attn rotary kernel on RoCm, it appears hipcc is unable to unroll loops, resulting in even slower inference than eager mode: https://github.com/pytorch/pytorch/issues/113773
+
+                head_size = query.shape[-1]
+
+                # Inplace operation, updating query and key.
+                pos_encoding_ops.rotary_embedding(
+                    query,
+                    key,
+                    head_size,
+                    cos,
+                    sin,
+                    True
+                )
             else:
                 raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
 
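On CUDA the flash-attn kernel is now invoked twice, once for query and once for key; on ROCm a single modified vllm kernel rotates both in place. Both apply the same neox-style rotation that the removed manual RoCm fallback spelled out. A pure-PyTorch reference, out-of-place here for clarity, with shapes as used by the callers above:

    import torch

    def apply_rotary_ref(x, cos, sin):
        # Reference of the neox-style rotation the kernels apply in place:
        # x is (num_tokens, num_heads, head_size), cos/sin are
        # (num_tokens, 1, rotary_dim) as returned by get_cos_sin below.
        rotary_dim = cos.shape[-1]
        x1 = x[..., :rotary_dim]
        x2 = x[..., rotary_dim : 2 * rotary_dim]
        rotated = x.clone()
        rotated[..., :rotary_dim] = x1 * cos - x2 * sin
        rotated[..., rotary_dim : 2 * rotary_dim] = x1 * sin + x2 * cos
        return rotated

    # The commit's query/key interface then amounts to:
    #   query.copy_(apply_rotary_ref(query, cos, sin))
    #   key.copy_(apply_rotary_ref(key, cos, sin))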
@@ -714,12 +718,18 @@ try:
             """
             Return cos and sin for the asked position ids
             """
+            if IS_ROCM_SYSTEM:
+                # For RoCm, we always use float cos/sin to avoid a cast.
+                # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of the same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
+                # But it later casts cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
+                dtype = torch.float32
 
             self._update_cos_sin_cache(dtype, position_ids.device, max_s)
 
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
-            return cos.unsqueeze(1), sin.unsqueeze(1)
+            # Note: this unsqueeze is not necessary on the RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet another control flow.
+            return cos.unsqueeze(1).float(), sin.unsqueeze(1).float()
 
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
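get_cos_sin serves positions out of cached cos/sin tables built from _create_inv_freq (whose first line appears in an earlier hunk) and now always returns float32, so neither back end needs a cast. A compact sketch of the cache and lookup; the helper name is hypothetical, and the real class builds the tables lazily via _update_cos_sin_cache:

    import torch

    def make_cos_sin_cache(dim, base, max_s):
        # Standard rotary frequency table: inv_freq[i] = 1 / base^(2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_s, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # (max_s, dim // 2)
        return torch.cos(freqs), torch.sin(freqs)

    cos_cached, sin_cached = make_cos_sin_cache(dim=128, base=10000, max_s=2048)
    position_ids = torch.tensor([0, 1, 2, 7])
    # Mirrors the return path above: gather per-position rows, add a broadcast
    # dim for the heads, and keep float32 for both the CUDA and ROCm kernels.
    cos = torch.index_select(cos_cached, 0, position_ids).unsqueeze(1).float()
    sin = torch.index_select(sin_cached, 0, position_ids).unsqueeze(1).float()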
@@ -729,7 +739,7 @@ try:
             self.max_position_embeddings = max_position_embeddings
             self.base = base
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
             if (