from typing import Callable, List, Optional

import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import UnquantizedWeight, Weights

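# Select the fused MoE kernels for the current system: IPEX ships a fused
# GatedMLPMOE module, CUDA loads the `moe` kernels from the Hub
# (kernels-community/moe), and other systems fall back to a locally
# installed `moe_kernels` package.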
if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else:
    import moe_kernels


class UnquantizedSparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

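        # Expert weights are loaded into two stacked tensors: the gate and up
        # projections of each expert are concatenated along the output
        # dimension, and all experts are stacked along a leading `n_experts`
        # dimension (see the loader helpers below).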
        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

        if SYSTEM == "ipex":
            self.ipex_fused_moe = GatedMLPMOE(
                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
            )

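    # Dispatch on the detected system: ROCm calls the kernel's fused_moe
    # directly, IPEX uses its prepacked module, and everything else goes
    # through the `fused_moe` wrapper defined below.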
    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        if SYSTEM == "rocm":
            return moe_kernels.fused_moe(
                x,
                self.gate_up_proj,
                self.down_proj,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                inplace=True,
            )
        elif SYSTEM == "ipex":
            return self.ipex_fused_moe(
                hidden_states=x,
                router_logits=gating_output,
                top_k=self.topk,
                renormalize=self.renormalize,
                use_grouped_topk=self.n_expert_group is not None,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
                scoring_func=self.scoring_func,
                e_score_correction_bias=self.e_score_correction_bias,
            )

        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
        )

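# A minimal usage sketch (hypothetical names and sizes; the real call sites
# live in the model implementations): the layer is constructed with the
# checkpoint prefix of the expert weights and called with the router logits:
#
#     moe = UnquantizedSparseMoELayer(
#         n_expert_group=None,
#         n_experts=8,
#         prefix="model.layers.0.mlp.experts",
#         renormalize=True,
#         topk=2,
#         topk_group=None,
#         weights=weights,
#     )
#     out = moe(hidden_states, gating_output=router_logits)
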
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight

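# Note: `get_weights_row` shards the down projection along the input
# dimension, matching the column split of the gate/up projections above.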
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_weights_row(
            f"{prefix}.{i}.{name}",
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight

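# `fused_moe` wraps the kernel's routing and expert computation: top-k (or
# grouped top-k) over the gating output, followed by `fused_experts`.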
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
"""
|
|
|
|
This function computes a Mixture of Experts (MoE) layer using two sets of
|
|
|
|
weights, w1 and w2, and top-k gating mechanism.
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
|
|
|
- w1 (torch.Tensor): The first set of expert weights.
|
|
|
|
- w2 (torch.Tensor): The second set of expert weights.
|
|
|
|
- gating_output (torch.Tensor): The output of the gating operation
|
|
|
|
(before softmax).
|
|
|
|
- topk (int): The number of top-k experts to select.
|
|
|
|
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
|
|
|
- inplace (bool): If True, perform the operation in-place.
|
|
|
|
Defaults to False.
|
|
|
|
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
|
|
|
- topk_group: Optional[int]: additional parameter for grouped_topk
|
|
|
|
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
|
|
|
note: Deepseekv2 model uses grouped_topk
|
|
|
|
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
|
|
|
products for w1 and w2. Defaults to False.
|
|
|
|
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
|
|
|
|
products for w1 and w2. Defaults to False.
|
|
|
|
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
|
|
|
activation to compute the inner products for w1 and w2.
|
|
|
|
Defaults to False.
|
|
|
|
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
|
|
w1.
|
|
|
|
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
|
|
w2.
|
|
|
|
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
|
|
a1.
|
|
|
|
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
|
|
a2.
|
|
|
|
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
|
|
|
quantization.
|
|
|
|
Returns:
|
|
|
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
|
|
|
"""
|
|
|
|
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
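        # Grouped top-k (used e.g. by DeepSeek-V2): expert scores are first
        # grouped, the best groups are kept, and top-k is taken within them.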
        topk_weights, topk_ids = moe_kernels.grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
    elif custom_routing_function is None:
        topk_weights, topk_ids = moe_kernels.fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize
        )

    return moe_kernels.fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )
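
# A minimal sketch of calling `fused_moe` directly, assuming a system where
# `moe_kernels` is available. Shapes are illustrative only: 4 tokens,
# 8 experts, hidden size 16, intermediate size 32, so w1 stacks the gate/up
# projections as [8, 2 * 32, 16] and w2 is [8, 16, 32]:
#
#     x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
#     w1 = torch.randn(8, 64, 16, dtype=torch.float16, device="cuda")
#     w2 = torch.randn(8, 16, 32, dtype=torch.float16, device="cuda")
#     logits = torch.randn(4, 8, dtype=torch.float16, device="cuda")
#     out = fused_moe(x, w1, w2, logits, topk=2, renormalize=True)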