from typing import Callable, List, Optional

import torch
import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import UnquantizedWeight, Weights

if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else:
    import moe_kernels


class UnquantizedSparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

        if SYSTEM == "ipex":
            self.ipex_fused_moe = GatedMLPMOE(
                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True
            )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        if SYSTEM == "rocm":
            return moe_kernels.fused_moe(
                x,
                self.gate_up_proj,
                self.down_proj,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                inplace=True,
            )
        elif SYSTEM == "ipex":
            return self.ipex_fused_moe(
                hidden_states=x,
                router_logits=gating_output,
                top_k=self.topk,
                renormalize=self.renormalize,
                use_grouped_topk=self.n_expert_group is not None,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
                scoring_func=self.scoring_func,
                e_score_correction_bias=self.e_score_correction_bias,
            )

        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
        )


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    for i in range(n_experts):
        weight = weights.get_weights_row(
            f"{prefix}.{i}.{name}",
        )

        assert isinstance(weight, UnquantizedWeight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )

        all_weight[i] = weight.weight

    assert all_weight is not None

    return all_weight


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - num_expert_group (Optional[int]): Additional parameter for grouped_topk.
    - topk_group (Optional[int]): Additional parameter for grouped_topk.
    - use_grouped_topk (bool): If True, use grouped_topk instead of fused_topk.
        Note: the DeepSeek-V2 model uses grouped_topk.
    - custom_routing_function (Optional[Callable]): If provided, used instead of
        fused_topk/grouped_topk to compute the top-k weights and expert ids.
    - scoring_func (str): Scoring function applied to the gating output,
        "softmax" by default.
    - e_score_correction_bias (Optional[torch.Tensor]): Optional bias added to
        the expert scores before top-k selection.
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
    - block_shape (Optional[List[int]]): Optional block size for block-wise
        quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = moe_kernels.grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
    elif custom_routing_function is None:
        topk_weights, topk_ids = moe_kernels.fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize
        )

    return moe_kernels.fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )
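

# Illustrative usage sketch for `fused_moe` (not used by the layer above). The
# shapes here are assumptions inferred from the expert-count assert in
# `fused_moe` and from how `UnquantizedSparseMoELayer` stacks its expert
# weights: w1 is assumed to hold the fused gate/up projections per expert with
# shape [n_experts, 2 * intermediate_size, hidden_size], and w2 the down
# projections with shape [n_experts, hidden_size, intermediate_size].
if __name__ == "__main__" and SYSTEM == "cuda" and torch.cuda.is_available():
    num_tokens, hidden_size, intermediate_size, n_experts = 4, 1024, 4096, 8
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
    w1 = torch.randn(
        n_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
    )
    w2 = torch.randn(
        n_experts, hidden_size, intermediate_size, dtype=torch.float16, device="cuda"
    )
    # Router logits: one score per expert for each token.
    router_logits = torch.randn(num_tokens, n_experts, dtype=torch.float16, device="cuda")
    out = fused_moe(x, w1, w2, router_logits, topk=2, renormalize=True)
    # The MoE output keeps the shape of the input hidden states.
    assert out.shape == x.shape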