Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

add moe support, fix qwen/mistral/mixtral crash

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

This commit is contained in:
parent 6bbe24d974
commit 5cd1c93cad
@@ -19,7 +19,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
 )
 
-from .fused_moe_ipex import fused_topk, grouped_topk
+from .fused_moe import fused_topk, grouped_topk
 
 # NOTE: we are using a protocol here, because multiple inherance is not nice.
 # We need `Module`, and `Module` -> some abstract class -> some concrete
@@ -1,11 +1,10 @@
-from typing import Callable, List, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
-moe_kernels = None
+from vllm_hpu_extension.ops import DynamicFusedMOE
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -54,21 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module):
             weights=weights,
         )
 
+        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+        for i in range(n_experts):
+            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        return fused_moe(
-            x,
-            w1=self.gate_up_proj,
-            w2=self.down_proj,
-            gating_output=gating_output,
-            topk=self.topk,
-            renormalize=self.renormalize,
-            inplace=True,
-            use_grouped_topk=self.n_expert_group is not None,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-            scoring_func=self.scoring_func,
-            e_score_correction_bias=self.e_score_correction_bias,
-        )
+        return self.hpu_fused_moe(x, gating_output, self.topk)
 
 
 def _load_expert_multi_weights_col(
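DynamicFusedMOE hides the routing and expert matmuls behind a single HPU op, so for orientation here is a naive PyTorch reference of what an unquantized top-k MoE forward computes. reference_moe_forward and the toy shapes are illustrative assumptions (SwiGLU experts with Mixtral-style gate/up packing), not the vllm_hpu_extension kernel:

    import torch
    import torch.nn.functional as F

    def reference_moe_forward(x, gate_up_proj, down_proj, gating_output, topk, renormalize=True):
        # x: [tokens, hidden]; gate_up_proj: [n_experts, 2 * inter, hidden]; down_proj: [n_experts, hidden, inter]
        probs = F.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):
            for weight, expert in zip(topk_weights[t], topk_ids[t]):
                gate_up = x[t] @ gate_up_proj[expert].T      # [2 * inter]
                gate, up = gate_up.chunk(2, dim=-1)
                act = F.silu(gate) * up                      # SwiGLU expert (assumed)
                out[t] += weight.to(x.dtype) * (act @ down_proj[expert].T)
        return out

    # Toy shapes: 4 tokens, hidden=16, inter=32, 8 experts, top-2 routing.
    x = torch.randn(4, 16)
    out = reference_moe_forward(
        x,
        gate_up_proj=torch.randn(8, 64, 16),
        down_proj=torch.randn(8, 16, 32),
        gating_output=torch.randn(4, 8),
        topk=2,
    )
    print(out.shape)  # torch.Size([4, 16])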
@@ -128,110 +119,3 @@ def _load_expert_weights_row(
     assert all_weight is not None
 
     return all_weight
-
-
-def fused_moe(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    use_grouped_topk: bool = False,
-    num_expert_group: Optional[int] = None,
-    topk_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
-    use_int8_w8a16: bool = False,
-    use_int4_w4a16: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
-) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - num_expert_group: Optional[int]: additional parameter for grouped_topk
-    - topk_group: Optional[int]: additional parameter for grouped_topk
-    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-        note: Deepseekv2 model uses grouped_topk
-    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
-        activation to compute the inner products for w1 and w2.
-        Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        a1.
-    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        a2.
-    - block_shape: (Optional[List[int]]): Optional block size for block-wise
-        quantization.
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-
-    if use_grouped_topk:
-        assert num_expert_group is not None and topk_group is not None
-        from loguru import logger
-        import inspect
-
-        logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
-        topk_weights, topk_ids = moe_kernels.grouped_topk(
-            hidden_states,
-            gating_output,
-            topk,
-            renormalize,
-            num_expert_group,
-            topk_group,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-        )
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = moe_kernels.fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
-
-    return moe_kernels.fused_experts(
-        hidden_states,
-        w1,
-        w2,
-        topk_weights,
-        topk_ids,
-        inplace=inplace,
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        use_int4_w4a16=use_int4_w4a16,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
-        block_shape=block_shape,
-    )
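The routing half of the deleted helper was either moe_kernels.fused_topk or, per the docstring note about Deepseekv2, moe_kernels.grouped_topk. As a reading aid only, here is a naive PyTorch sketch of group-limited top-k routing; it is not the moe_kernels implementation, it ignores scoring_func and e_score_correction_bias, and it assumes n_experts is divisible by num_expert_group:

    import torch

    def naive_grouped_topk(gating_output, topk, renormalize, num_expert_group, topk_group):
        # Group-limited routing in the spirit of the deleted docstring's Deepseekv2 note:
        # experts are split into num_expert_group groups, only the best topk_group groups
        # stay eligible, then an ordinary top-k is taken over the surviving experts.
        scores = torch.softmax(gating_output.float(), dim=-1)        # [tokens, n_experts]
        tokens, n_experts = scores.shape
        group_scores = scores.view(tokens, num_expert_group, -1).max(dim=-1).values
        group_idx = torch.topk(group_scores, topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        expert_mask = group_mask.unsqueeze(-1).expand(
            tokens, num_expert_group, n_experts // num_expert_group
        ).reshape(tokens, -1)
        masked_scores = scores.masked_fill(expert_mask == 0, 0.0)
        topk_weights, topk_ids = torch.topk(masked_scores, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

    # 2 tokens, 8 experts in 4 groups; keep the 2 best groups, then pick 2 experts per token.
    w, ids = naive_grouped_topk(torch.randn(2, 8), topk=2, renormalize=True,
                                num_expert_group=4, topk_group=2)
    print(ids)  # expert ids restricted to the 2 highest-scoring groups per token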
@@ -263,7 +263,7 @@ class FlashCohereAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key,
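From here on, the per-model hunks replace the `# flash attention` comment with `# sdpa` in the prefill branch; the attention() helper keeps its signature (query, key, kv_cache, kv_scales, seqlen, ...). Purely for orientation, torch's built-in SDPA call looks like the following; the shapes are made up and this is not the helper these models actually call:

    import torch
    import torch.nn.functional as F

    # Made-up shapes: [batch, heads, seq_len, head_dim].
    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([1, 8, 128, 64])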
@@ -345,7 +345,7 @@ class DbrxAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
@@ -268,7 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
@@ -240,7 +240,7 @@ class FlashGemmaAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
@@ -242,7 +242,7 @@ class FlashGPT2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key,
@@ -193,7 +193,7 @@ class FlashGPTJAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key,
@@ -235,7 +235,7 @@ class FlashLlamaAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
@@ -652,6 +652,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if prefill_cache_indices is not None and slots.size(
+            0
+        ) != prefill_cache_indices.size(0):
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
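The guard added here matches what the Mistral and Mixtral forwards already do a few hunks below: slots is sized like the whole kv tensor, so when prefill_cache_indices selects a subset of positions, slots must be narrowed the same way. A self-contained sketch of that indexing, with made-up sizes:

    import torch

    # Hypothetical sizes: the kv tensor covers 16 slots, but this prefill only owns 10 of them.
    slots = torch.arange(16)
    prefill_cache_indices = torch.arange(10)

    if prefill_cache_indices is not None and slots.size(0) != prefill_cache_indices.size(0):
        # Same slicing as the added lines: slots matches the whole kv tensor, so narrow it too.
        slots = slots[prefill_cache_indices]

    print(slots.size(0))  # 10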
@@ -212,11 +212,11 @@ class MistralAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -488,10 +488,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
         ) != prefill_cache_indices.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
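The four deleted lines are the decode-time clamp that sliding-window models apply to seqlen before paged attention; the same block is removed from Mixtral, Qwen2 and Starcoder2 below, which appears to be part of the crash fix named in the commit title. For reference, the removed clamp is just an upper bound on the reported sequence lengths (toy sketch, hypothetical window size):

    import torch

    # Toy values: three sequences, a sliding window of 4096 tokens.
    seqlen = torch.tensor([3, 4096, 12])
    max_past = 4096 - 1                     # hypothetical max_past_tensor value
    print(seqlen.clamp(max=max_past))       # tensor([   3, 4095,   12])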
@@ -268,11 +268,11 @@ class MixtralAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -523,10 +523,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         ) != prefill_cache_indices.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -180,7 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=qkv[:, 0],
                 key=qkv[:, 1],
@@ -139,11 +139,11 @@ class Qwen2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -378,10 +378,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
         ) != slots.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
 
@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=kv[:, 0],
@@ -300,7 +300,7 @@ class FlashMQAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
                 key=key_value[:, 0],
@@ -269,11 +269,11 @@ class Starcoder2Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
                 query=query,
-                key=kv_to_cache[:, 0],
-                value=kv_to_cache[:, 1],
+                key=kv[:, 0],
+                value=kv[:, 1],
                 kv_cache=kv_cache,
                 kv_scales=self.kv_scales,
                 seqlen=seqlen,
@@ -602,10 +602,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         ) != prefill_cache_indices.size(0):
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        elif self.max_past is not None:
-            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-            # kernel requires the true values
-            seqlen = seqlen.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,