| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | from typing import Callable, List, Optional | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.nn as nn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from text_generation_server.utils.import_utils import SYSTEM | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | from text_generation_server.utils.kernels import load_kernel | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  | from text_generation_server.utils.weights import UnquantizedWeight, Weights | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 11:44:42 +00:00
										 |  |  | if SYSTEM == "ipex": | 
					
						
							| 
									
										
										
										
											2024-11-18 16:16:55 +00:00
										 |  |  |     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | elif SYSTEM == "cuda": | 
					
						
							|  |  |  |     moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe") | 
					
						
							| 
									
										
										
										
											2024-11-19 07:04:23 +00:00
										 |  |  | else: | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  |     import moe_kernels | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class UnquantizedSparseMoELayer(nn.Module): | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         *, | 
					
						
							|  |  |  |         n_expert_group: Optional[int], | 
					
						
							|  |  |  |         n_experts: int, | 
					
						
							|  |  |  |         prefix: str, | 
					
						
							|  |  |  |         renormalize: bool, | 
					
						
							|  |  |  |         topk: int, | 
					
						
							|  |  |  |         topk_group: Optional[int], | 
					
						
							|  |  |  |         weights: Weights, | 
					
						
							| 
									
										
										
										
											2025-01-30 15:40:25 +00:00
										 |  |  |         scoring_func: Optional[str] = "softmax", | 
					
						
							|  |  |  |         e_score_correction_bias: Optional[float] = None, | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  |         gate_proj_name: str = "gate_proj", | 
					
						
							|  |  |  |         up_proj_name: str = "up_proj", | 
					
						
							|  |  |  |         down_proj_name: str = "down_proj", | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert (n_expert_group is None) == ( | 
					
						
							|  |  |  |             topk_group is None | 
					
						
							|  |  |  |         ), "n_expert_group and topk_group must both be None or have some value" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.n_expert_group = n_expert_group | 
					
						
							|  |  |  |         self.topk = topk | 
					
						
							|  |  |  |         self.topk_group = topk_group | 
					
						
							|  |  |  |         self.renormalize = renormalize | 
					
						
							| 
									
										
										
										
											2025-01-30 15:40:25 +00:00
										 |  |  |         self.weight_block_size = weights.weights_loader.weight_block_size | 
					
						
							|  |  |  |         self.scoring_func = scoring_func | 
					
						
							|  |  |  |         self.e_score_correction_bias = e_score_correction_bias | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         self.gate_up_proj = _load_expert_multi_weights_col( | 
					
						
							|  |  |  |             prefix=prefix, | 
					
						
							|  |  |  |             n_experts=n_experts, | 
					
						
							|  |  |  |             gate_proj_name=gate_proj_name, | 
					
						
							|  |  |  |             up_proj_name=up_proj_name, | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.down_proj = _load_expert_weights_row( | 
					
						
							|  |  |  |             prefix=prefix, | 
					
						
							|  |  |  |             n_experts=n_experts, | 
					
						
							|  |  |  |             name=down_proj_name, | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-11-18 16:16:55 +00:00
										 |  |  |         if SYSTEM == "ipex": | 
					
						
							|  |  |  |             self.ipex_fused_moe = GatedMLPMOE( | 
					
						
							|  |  |  |                 W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  |         if SYSTEM == "rocm": | 
					
						
							|  |  |  |             return moe_kernels.fused_moe( | 
					
						
							|  |  |  |                 x, | 
					
						
							|  |  |  |                 self.gate_up_proj, | 
					
						
							|  |  |  |                 self.down_proj, | 
					
						
							|  |  |  |                 gating_output, | 
					
						
							|  |  |  |                 self.topk, | 
					
						
							|  |  |  |                 renormalize=self.renormalize, | 
					
						
							|  |  |  |                 inplace=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         elif SYSTEM == "ipex": | 
					
						
							| 
									
										
										
										
											2024-11-18 16:16:55 +00:00
										 |  |  |             return self.ipex_fused_moe( | 
					
						
							|  |  |  |                 hidden_states=x, | 
					
						
							|  |  |  |                 router_logits=gating_output, | 
					
						
							|  |  |  |                 top_k=self.topk, | 
					
						
							|  |  |  |                 renormalize=self.renormalize, | 
					
						
							|  |  |  |                 use_grouped_topk=self.n_expert_group is not None, | 
					
						
							|  |  |  |                 num_expert_group=self.n_expert_group, | 
					
						
							|  |  |  |                 topk_group=self.topk_group, | 
					
						
							| 
									
										
										
										
											2025-02-25 11:07:55 +00:00
										 |  |  |                 scoring_func=self.scoring_func, | 
					
						
							|  |  |  |                 e_score_correction_bias=self.e_score_correction_bias, | 
					
						
							| 
									
										
										
										
											2024-11-18 16:16:55 +00:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  |         return fused_moe( | 
					
						
							|  |  |  |             x, | 
					
						
							|  |  |  |             w1=self.gate_up_proj, | 
					
						
							|  |  |  |             w2=self.down_proj, | 
					
						
							|  |  |  |             gating_output=gating_output, | 
					
						
							|  |  |  |             topk=self.topk, | 
					
						
							|  |  |  |             renormalize=self.renormalize, | 
					
						
							|  |  |  |             inplace=True, | 
					
						
							|  |  |  |             use_grouped_topk=self.n_expert_group is not None, | 
					
						
							|  |  |  |             num_expert_group=self.n_expert_group, | 
					
						
							|  |  |  |             topk_group=self.topk_group, | 
					
						
							| 
									
										
										
										
											2025-01-30 15:40:25 +00:00
										 |  |  |             scoring_func=self.scoring_func, | 
					
						
							|  |  |  |             e_score_correction_bias=self.e_score_correction_bias, | 
					
						
							| 
									
										
										
										
											2024-09-17 16:08:58 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _load_expert_multi_weights_col( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     prefix: str, | 
					
						
							|  |  |  |     n_experts: int, | 
					
						
							|  |  |  |     gate_proj_name: str, | 
					
						
							|  |  |  |     up_proj_name: str, | 
					
						
							|  |  |  |     weights: Weights, | 
					
						
							|  |  |  | ) -> torch.Tensor: | 
					
						
							|  |  |  |     all_weight = None | 
					
						
							|  |  |  |     for i in range(n_experts): | 
					
						
							|  |  |  |         weight = weights.get_multi_weights_col( | 
					
						
							|  |  |  |             [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert isinstance(weight, UnquantizedWeight) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if all_weight is None: | 
					
						
							|  |  |  |             all_weight = torch.empty( | 
					
						
							|  |  |  |                 (n_experts,) + weight.weight.shape, | 
					
						
							|  |  |  |                 dtype=weight.weight.dtype, | 
					
						
							|  |  |  |                 device=weight.weight.device, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         all_weight[i] = weight.weight | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert all_weight is not None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return all_weight | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _load_expert_weights_row( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     prefix: str, | 
					
						
							|  |  |  |     n_experts: int, | 
					
						
							|  |  |  |     name: str, | 
					
						
							|  |  |  |     weights: Weights, | 
					
						
							|  |  |  | ) -> torch.Tensor: | 
					
						
							|  |  |  |     all_weight = None | 
					
						
							|  |  |  |     for i in range(n_experts): | 
					
						
							|  |  |  |         weight = weights.get_weights_row( | 
					
						
							|  |  |  |             f"{prefix}.{i}.{name}", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert isinstance(weight, UnquantizedWeight) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if all_weight is None: | 
					
						
							|  |  |  |             all_weight = torch.empty( | 
					
						
							|  |  |  |                 (n_experts,) + weight.weight.shape, | 
					
						
							|  |  |  |                 dtype=weight.weight.dtype, | 
					
						
							|  |  |  |                 device=weight.weight.device, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         all_weight[i] = weight.weight | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert all_weight is not None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return all_weight | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def fused_moe( | 
					
						
							|  |  |  |     hidden_states: torch.Tensor, | 
					
						
							|  |  |  |     w1: torch.Tensor, | 
					
						
							|  |  |  |     w2: torch.Tensor, | 
					
						
							|  |  |  |     gating_output: torch.Tensor, | 
					
						
							|  |  |  |     topk: int, | 
					
						
							|  |  |  |     renormalize: bool, | 
					
						
							|  |  |  |     inplace: bool = False, | 
					
						
							|  |  |  |     use_grouped_topk: bool = False, | 
					
						
							|  |  |  |     num_expert_group: Optional[int] = None, | 
					
						
							|  |  |  |     topk_group: Optional[int] = None, | 
					
						
							|  |  |  |     custom_routing_function: Optional[Callable] = None, | 
					
						
							|  |  |  |     scoring_func: str = "softmax", | 
					
						
							|  |  |  |     e_score_correction_bias: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     use_fp8_w8a8: bool = False, | 
					
						
							|  |  |  |     use_int8_w8a16: bool = False, | 
					
						
							|  |  |  |     use_int4_w4a16: bool = False, | 
					
						
							|  |  |  |     w1_scale: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     w2_scale: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     a1_scale: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     a2_scale: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     block_shape: Optional[List[int]] = None, | 
					
						
							|  |  |  | ) -> torch.Tensor: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     This function computes a Mixture of Experts (MoE) layer using two sets of | 
					
						
							|  |  |  |     weights, w1 and w2, and top-k gating mechanism. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Parameters: | 
					
						
							|  |  |  |     - hidden_states (torch.Tensor): The input tensor to the MoE layer. | 
					
						
							|  |  |  |     - w1 (torch.Tensor): The first set of expert weights. | 
					
						
							|  |  |  |     - w2 (torch.Tensor): The second set of expert weights. | 
					
						
							|  |  |  |     - gating_output (torch.Tensor): The output of the gating operation | 
					
						
							|  |  |  |         (before softmax). | 
					
						
							|  |  |  |     - topk (int): The number of top-k experts to select. | 
					
						
							|  |  |  |     - renormalize (bool): If True, renormalize the top-k weights to sum to 1. | 
					
						
							|  |  |  |     - inplace (bool): If True, perform the operation in-place. | 
					
						
							|  |  |  |         Defaults to False. | 
					
						
							|  |  |  |     - num_expert_group: Optional[int]: additional parameter for grouped_topk | 
					
						
							|  |  |  |     - topk_group: Optional[int]: additional parameter for grouped_topk | 
					
						
							|  |  |  |     - use_grouped_topk: If True, use grouped_topk instead of fused_topk | 
					
						
							|  |  |  |         note: Deepseekv2 model uses grouped_topk | 
					
						
							|  |  |  |     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner | 
					
						
							|  |  |  |         products for w1 and w2. Defaults to False. | 
					
						
							|  |  |  |     - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner | 
					
						
							|  |  |  |         products for w1 and w2. Defaults to False. | 
					
						
							|  |  |  |     - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 | 
					
						
							|  |  |  |         activation to compute the inner products for w1 and w2. | 
					
						
							|  |  |  |         Defaults to False. | 
					
						
							|  |  |  |     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for | 
					
						
							|  |  |  |         w1. | 
					
						
							|  |  |  |     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for | 
					
						
							|  |  |  |         w2. | 
					
						
							|  |  |  |     - a1_scale (Optional[torch.Tensor]): Optional scale to be used for | 
					
						
							|  |  |  |         a1. | 
					
						
							|  |  |  |     - a2_scale (Optional[torch.Tensor]): Optional scale to be used for | 
					
						
							|  |  |  |         a2. | 
					
						
							|  |  |  |     - block_shape: (Optional[List[int]]): Optional block size for block-wise | 
					
						
							|  |  |  |         quantization. | 
					
						
							|  |  |  |     Returns: | 
					
						
							|  |  |  |     - torch.Tensor: The output tensor after applying the MoE layer. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     # Check constraints. | 
					
						
							|  |  |  |     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if use_grouped_topk: | 
					
						
							|  |  |  |         assert num_expert_group is not None and topk_group is not None | 
					
						
							|  |  |  |         from loguru import logger | 
					
						
							|  |  |  |         import inspect | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}") | 
					
						
							|  |  |  |         topk_weights, topk_ids = moe_kernels.grouped_topk( | 
					
						
							|  |  |  |             hidden_states, | 
					
						
							|  |  |  |             gating_output, | 
					
						
							|  |  |  |             topk, | 
					
						
							|  |  |  |             renormalize, | 
					
						
							|  |  |  |             num_expert_group, | 
					
						
							|  |  |  |             topk_group, | 
					
						
							|  |  |  |             scoring_func=scoring_func, | 
					
						
							|  |  |  |             e_score_correction_bias=e_score_correction_bias, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     elif custom_routing_function is None: | 
					
						
							|  |  |  |         topk_weights, topk_ids = moe_kernels.fused_topk( | 
					
						
							|  |  |  |             hidden_states, gating_output, topk, renormalize | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         topk_weights, topk_ids = custom_routing_function( | 
					
						
							|  |  |  |             hidden_states, gating_output, topk, renormalize | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return moe_kernels.fused_experts( | 
					
						
							|  |  |  |         hidden_states, | 
					
						
							|  |  |  |         w1, | 
					
						
							|  |  |  |         w2, | 
					
						
							|  |  |  |         topk_weights, | 
					
						
							|  |  |  |         topk_ids, | 
					
						
							|  |  |  |         inplace=inplace, | 
					
						
							|  |  |  |         use_fp8_w8a8=use_fp8_w8a8, | 
					
						
							|  |  |  |         use_int8_w8a16=use_int8_w8a16, | 
					
						
							|  |  |  |         use_int4_w4a16=use_int4_w4a16, | 
					
						
							|  |  |  |         w1_scale=w1_scale, | 
					
						
							|  |  |  |         w2_scale=w2_scale, | 
					
						
							|  |  |  |         a1_scale=a1_scale, | 
					
						
							|  |  |  |         a2_scale=a2_scale, | 
					
						
							|  |  |  |         block_shape=block_shape, | 
					
						
							|  |  |  |     ) |