From 5cd1c93cad96fa0c00deb7be26becfce2854084b Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 18 Mar 2025 00:45:15 -0700
Subject: [PATCH] add moe support, fix qwen/mistral/mixtral crash

Signed-off-by: Wang, Yi A
---
 .../layers/moe/__init__.py                    |   2 +-
 .../moe/{fused_moe_ipex.py => fused_moe.py}   |   0
 .../layers/moe/unquantized.py                 | 132 ++----------------
 .../custom_modeling/flash_cohere_modeling.py  |   2 +-
 .../custom_modeling/flash_dbrx_modeling.py    |   2 +-
 .../custom_modeling/flash_gemma2_modeling.py  |   2 +-
 .../custom_modeling/flash_gemma_modeling.py   |   2 +-
 .../custom_modeling/flash_gpt2_modeling.py    |   2 +-
 .../custom_modeling/flash_gptj_modeling.py    |   2 +-
 .../custom_modeling/flash_llama_modeling.py   |   7 +-
 .../custom_modeling/flash_mistral_modeling.py |  10 +-
 .../custom_modeling/flash_mixtral_modeling.py |  10 +-
 .../custom_modeling/flash_neox_modeling.py    |   2 +-
 .../custom_modeling/flash_qwen2_modeling.py   |  10 +-
 .../custom_modeling/flash_rw_modeling.py      |   2 +-
 .../flash_santacoder_modeling.py              |   2 +-
 .../flash_starcoder2_modeling.py              |  10 +-
 17 files changed, 36 insertions(+), 163 deletions(-)
 rename backends/gaudi/server/text_generation_server/layers/moe/{fused_moe_ipex.py => fused_moe.py} (100%)

diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
index cba81407..8b9d6fcb 100644
--- a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
+++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
@@ -19,7 +19,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
 )
 
-from .fused_moe_ipex import fused_topk, grouped_topk
+from .fused_moe import fused_topk, grouped_topk
 
 # NOTE: we are using a protocol here, because multiple inherance is not nice.
 # We need `Module`, and `Module` -> some abstract class -> some concrete
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
similarity index 100%
rename from backends/gaudi/server/text_generation_server/layers/moe/fused_moe_ipex.py
rename to backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
index 8cb27879..ec158398 100644
--- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
+++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
@@ -1,11 +1,10 @@
-from typing import Callable, List, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
-
-moe_kernels = None
+from vllm_hpu_extension.ops import DynamicFusedMOE
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -54,21 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module):
             weights=weights,
         )
 
+        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+        for i in range(n_experts):
+            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        return fused_moe(
-            x,
-            w1=self.gate_up_proj,
-            w2=self.down_proj,
-            gating_output=gating_output,
-            topk=self.topk,
-            renormalize=self.renormalize,
-            inplace=True,
-            use_grouped_topk=self.n_expert_group is not None,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-            scoring_func=self.scoring_func,
-            e_score_correction_bias=self.e_score_correction_bias,
-        )
+        return self.hpu_fused_moe(x, gating_output, self.topk)
 
 
 def _load_expert_multi_weights_col(
@@ -128,110 +119,3 @@ def _load_expert_weights_row(
     assert all_weight is not None
 
     return all_weight
-
-
-def fused_moe(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    use_grouped_topk: bool = False,
-    num_expert_group: Optional[int] = None,
-    topk_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
-    use_int8_w8a16: bool = False,
-    use_int4_w4a16: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
-) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - num_expert_group: Optional[int]: additional parameter for grouped_topk
-    - topk_group: Optional[int]: additional parameter for grouped_topk
-    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-        note: Deepseekv2 model uses grouped_topk
-    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
-        activation to compute the inner products for w1 and w2.
-        Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        a1.
-    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        a2.
-    - block_shape: (Optional[List[int]]): Optional block size for block-wise
-        quantization.
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-
-    if use_grouped_topk:
-        assert num_expert_group is not None and topk_group is not None
-        from loguru import logger
-        import inspect
-
-        logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
-        topk_weights, topk_ids = moe_kernels.grouped_topk(
-            hidden_states,
-            gating_output,
-            topk,
-            renormalize,
-            num_expert_group,
-            topk_group,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-        )
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = moe_kernels.fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
-
-    return moe_kernels.fused_experts(
-        hidden_states,
-        w1,
-        w2,
-        topk_weights,
-        topk_ids,
-        inplace=inplace,
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        use_int4_w4a16=use_int4_w4a16,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
-        block_shape=block_shape,
-    )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 8d32032d..77dec80d 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -263,7 +263,7 @@ class FlashCohereAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index c01bd1bc..0f1338ca 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -345,7 +345,7 @@ class DbrxAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
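
For reference, a minimal standalone sketch (not part of the patch) of how the new HPU path in layers/moe/unquantized.py is wired, assuming a Gaudi environment with vllm_hpu_extension installed; the sizes below are illustrative placeholders, only the DynamicFusedMOE/MoeOp calls mirror the diff above.

import torch
from vllm_hpu_extension.ops import DynamicFusedMOE  # HPU fused-MoE op used by the patch

# Illustrative sizes (assumptions, not taken from the patch).
n_experts, hidden_size, intermediate_size = 4, 64, 128

# Per-expert weights as held by UnquantizedSparseMoELayer:
# gate_up_proj packs gate and up projections, down_proj maps back to hidden_size.
gate_up_proj = torch.randn(n_experts, 2 * intermediate_size, hidden_size, dtype=torch.bfloat16)
down_proj = torch.randn(n_experts, hidden_size, intermediate_size, dtype=torch.bfloat16)

# Register each expert's weights with the fused op, as done in __init__ above.
hpu_fused_moe = DynamicFusedMOE(n_experts)
for i in range(n_experts):
    hpu_fused_moe.MoeOp.w13_list[i].set_weight(gate_up_proj[i])
    hpu_fused_moe.MoeOp.w2_list[i].set_weight(down_proj[i])

# forward() then reduces to a single call: raw router logits and top-k go straight in.
# x: (num_tokens, hidden_size), gating_output: (num_tokens, n_experts)
# out = hpu_fused_moe(x, gating_output, topk)
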
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index 5b7adad1..632e8017 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -268,7 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index d26184b6..d832fb00 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -240,7 +240,7 @@ class FlashGemmaAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index a6e0a7de..80236fe8 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -242,7 +242,7 @@ class FlashGPT2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index 9229a453..3135acde 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -193,7 +193,7 @@ class FlashGPTJAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 857e1757..a0c4fb8c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -235,7 +235,7 @@ class FlashLlamaAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
@@ -652,6 +652,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if prefill_cache_indices is not None and slots.size(
+            0
+        ) != prefill_cache_indices.size(0):
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
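
A toy, self-contained illustration (not from the patch) of the guard added to FlashLlamaForCausalLM.forward above, which the mistral/mixtral/qwen2/starcoder2 models below already carry: slots holds one cache slot per flattened kv position, so when prefill only writes a subset of positions (prefill_cache_indices), slots must be gathered with the same indices. The tensor values are hypothetical.

import torch

# Hypothetical values: one cache slot per kv position in the flattened batch.
slots = torch.tensor([10, 11, 12, 13, 14, 15])
# Positions actually written to the kv cache during this prefill step.
prefill_cache_indices = torch.tensor([0, 1, 2, 4, 5])

# Same check/slice as the added code: slots must line up with the cached positions.
if prefill_cache_indices is not None and slots.size(0) != prefill_cache_indices.size(0):
    slots = slots[prefill_cache_indices]

print(slots)  # tensor([10, 11, 12, 14, 15])
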
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 8214b6b7..38eba082 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -212,11 +212,11 @@ class MistralAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -488,10 +488,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
         ) != prefill_cache_indices.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 18ffe060..fbcb0970 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -268,11 +268,11 @@ class MixtralAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -523,10 +523,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         ) != prefill_cache_indices.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 76269f22..d1904c03 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -180,7 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=qkv[:, 0],
                 key=qkv[:, 1],
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index c62435fe..480a17d1 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -139,11 +139,11 @@ class Qwen2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -378,10 +378,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
         ) != slots.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index c6034bf0..e7c4b2b6 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 9b24e8ba..57d4ee64 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -300,7 +300,7 @@ class FlashMQAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key_value[:, 0],
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index d12bee5c..082e5d82 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -269,11 +269,11 @@ class Starcoder2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -602,10 +602,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         ) != prefill_cache_indices.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
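
As a conceptual reference only (an assumption about the routing semantics, not the code of the renamed fused_moe module): the fused_topk/grouped_topk helpers kept by layers/moe/__init__.py amount to a softmax over the router logits, top-k selection, and optional renormalization of the kept weights, while the HPU fused op above consumes the raw gating_output and topk directly. A plain-PyTorch sketch of that routing step:

import torch

def topk_routing(gating_output: torch.Tensor, topk: int, renormalize: bool = True):
    # Softmax over experts, then keep the top-k weights and their expert ids.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

# Example: 4 tokens routed over 8 experts, 2 experts per token (illustrative sizes).
weights, ids = topk_routing(torch.randn(4, 8), topk=2)
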