text-generation-inference/server/text_generation_server/layers/moe/unquantized.py

271 lines
8.8 KiB
Python
Raw Normal View History

from typing import Callable, List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
2024-11-19 07:04:23 +00:00
else:
import moe_kernels
class UnquantizedSparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj = _load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm":
return moe_kernels.fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
top_k=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
)
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
)
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
assert isinstance(weight, UnquantizedWeight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=weight.weight.dtype,
device=weight.weight.device,
)
all_weight[i] = weight.weight
assert all_weight is not None
return all_weight
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
assert isinstance(weight, UnquantizedWeight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=weight.weight.dtype,
device=weight.weight.device,
)
all_weight[i] = weight.weight
assert all_weight is not None
return all_weight
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
from loguru import logger
import inspect
logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
topk_weights, topk_ids = moe_kernels.grouped_topk(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
elif custom_routing_function is None:
topk_weights, topk_ids = moe_kernels.fused_topk(
hidden_states, gating_output, topk, renormalize
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize
)
return moe_kernels.fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)