diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index 8e5c2210..ddb648ea 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -4,8 +4,8 @@ build-vllm-cuda: BRANCH=main
 build-vllm-cuda: build-vllm

 build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=65f4a79621b4d992cf97f6b84598804eb4ca87b6
-build-vllm-rocm: BRANCH=port-to-rocm
+build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
 build-vllm-rocm: build-vllm

 vllm:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index ad0b20b5..4aeb447d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -223,9 +223,6 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        # self.rotary_emb = PositionRotaryEmbedding.load(
-        #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
-        # )
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
@@ -280,9 +277,8 @@ class FlashLlamaAttention(torch.nn.Module):
         )
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
-
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index ec05bc35..959949f0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -280,8 +280,7 @@ class MistralAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index af4ba96b..eea5f787 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -135,8 +135,7 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)

         paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 00f953a6..6a530f3c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -185,8 +185,7 @@ class FlashRWAttention(torch.nn.Module):
         kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)

         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
@@ -301,8 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)

         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
+        self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

         paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 9c3bfa4f..946f7683 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -425,8 +425,6 @@ class IdeficsRMSNorm(nn.Module):
                 self.weight.data,
                 self.variance_epsilon,
             )
-            if res is None:
-                res = hidden_states

             if unwrap:
                 out = out.view(*shape)
@@ -613,15 +611,12 @@ class IdeficsAttention(nn.Module):
                 position_ids.view(-1), max_s, hidden_states.dtype
             )

-            shape = query_states.shape
-            query_states = self.rotary_emb(
-                query_states.view(-1, *shape[2:]), cos, sin
-            ).view(shape)
-
-            shape = key_states.shape
-            key_states = self.rotary_emb(
-                key_states.reshape(-1, *shape[2:]), cos, sin
-            ).view(shape)
+            query_shape = query_states.shape
+            key_shape = key_states.shape
+            self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
+
+            query_states = query_states.view(query_shape)
+            key_states = key_states.view(key_shape)

         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 209df95c..f8bf9c7b 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -555,6 +555,8 @@ try:
     if IS_CUDA_SYSTEM:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
+    elif IS_ROCM_SYSTEM:
+        from vllm import pos_encoding_ops

     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -583,32 +585,34 @@ try:
             self.scaling_factor = scaling_factor
             self.dynamic_args = None

-        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             # Such controlflows may add some overhead.
             if IS_CUDA_SYSTEM:
                 rotary_dim = cos.shape[-1]
-                x1 = x[..., :rotary_dim]
-                x2 = x[..., rotary_dim : 2 * rotary_dim]
+                q1 = query[..., :rotary_dim]
+                q2 = query[..., rotary_dim : 2 * rotary_dim]

-                rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-                return x
+                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+                k1 = key[..., :rotary_dim]
+                k2 = key[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
             elif IS_ROCM_SYSTEM:
-                # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
-                # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
-                rotary_dim = cos.shape[-1]
+                # NOTE: On RoCm systems, we use a ROPE implementation adapted from VLLM, which launches a single kernel for both query and key, unlike the flash-attn implementation used on NVIDIA systems.
+                # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

-                dtype = x.dtype
-                x_upcast = x.to(torch.float32)
-                cos = cos.to(torch.float32)
-                sin = sin.to(torch.float32)
+                head_size = query.shape[-1]

-                x1 = x_upcast[..., :rotary_dim]
-                x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
-
-                # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
-                x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
-                x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
-                return x
+                # Inplace operation, updating query and key.
+                pos_encoding_ops.rotary_embedding(
+                    query,
+                    key,
+                    head_size,
+                    cos,
+                    sin,
+                    True
+                )
             else:
                 raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
@@ -714,12 +718,18 @@ try:
             """
             Return cos and sin for the asked position ids
             """
+            if IS_ROCM_SYSTEM:
+                # For RoCm, we always use float cos/sin to avoid a cast.
+                # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of the same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
+                # But it later casts cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
+                dtype = torch.float32

             self._update_cos_sin_cache(dtype, position_ids.device, max_s)

             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
-            return cos.unsqueeze(1), sin.unsqueeze(1)
+            # Note: this unsqueeze is not necessary on the RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet another control flow.
+            return cos.unsqueeze(1).float(), sin.unsqueeze(1).float()

     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
@@ -729,7 +739,7 @@ try:
             self.max_position_embeddings = max_position_embeddings
             self.base = base

-        def _update_cos_sin_cache(self, dtype, device, seqlen):
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
             if (
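For reference, the sketch below restates in plain PyTorch what the fused PositionRotaryEmbedding.forward(query, key, cos, sin) contract introduced by this diff is expected to compute in-place on both tensors: a NeoX-style (split-half) rotation over the first 2 * rotary_dim features, with cos/sin already gathered per token by get_cos_sin and unsqueezed so they broadcast over the head dimension. It is not the flash-attn or VLLM kernel the patch dispatches to; the function name apply_rotary_reference and the shapes in the usage example are illustrative only, e.g. as a numerical cross-check between the CUDA and RoCm paths.

# Reference-only sketch of the fused in-place rotary contract; not the
# kernels used in the patch. Assumes NeoX-style (split-half) rotation.
import torch


def apply_rotary_reference(
    query: torch.Tensor,  # [num_tokens, num_heads, head_size]
    key: torch.Tensor,    # [num_tokens, num_kv_heads, head_size]
    cos: torch.Tensor,    # [num_tokens, 1, rotary_dim]
    sin: torch.Tensor,    # [num_tokens, 1, rotary_dim]
) -> None:
    """Applies rotary position embedding to query and key in-place."""
    rotary_dim = cos.shape[-1]
    for x in (query, key):
        # Clone the two halves so the second assignment does not read
        # already-rotated values.
        x1 = x[..., :rotary_dim].clone()
        x2 = x[..., rotary_dim : 2 * rotary_dim].clone()
        # Features past 2 * rotary_dim (if any) are left untouched.
        x[..., :rotary_dim] = x1 * cos - x2 * sin
        x[..., rotary_dim : 2 * rotary_dim] = x1 * sin + x2 * cos


if __name__ == "__main__":
    num_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    # cos/sin as returned by get_cos_sin: one row per token, unsqueezed to
    # broadcast over the head dimension.
    angles = torch.randn(num_tokens, 1, head_size // 2)
    apply_rotary_reference(query, key, angles.cos(), angles.sin())
    print(query.shape, key.shape)

Collapsing the two per-tensor calls into a single query/key call is what allows the RoCm branch to launch one VLLM kernel per attention layer, while the CUDA branch simply runs the two flash-attn apply_rotary calls back to back inside the same forward.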