| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | from dataclasses import dataclass | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | from typing import Callable, List, Optional | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.nn as nn | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | from text_generation_server.layers import moe | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | from text_generation_server.utils.import_utils import SYSTEM | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | from text_generation_server.utils.kernels import load_kernel | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | from text_generation_server.utils.weights import Weights | 
					
						
							|  |  |  | from text_generation_server.layers.marlin.gptq import ( | 
					
						
							|  |  |  |     GPTQMarlinWeight, | 
					
						
							|  |  |  |     GPTQMarlinWeightsLoader, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if SYSTEM == "cuda": | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  |     moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe") | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | else: | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  |     moe_kernels = None | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     major, _minor = torch.cuda.get_device_capability() | 
					
						
							|  |  |  |     has_sm_8_0 = major >= 8 | 
					
						
							|  |  |  | except Exception: | 
					
						
							|  |  |  |     has_sm_8_0 = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def can_use_marlin_moe_gemm( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     quant_method: str, | 
					
						
							|  |  |  |     quantize: str, | 
					
						
							|  |  |  |     sym: bool, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     return ( | 
					
						
							|  |  |  |         SYSTEM == "cuda" | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  |         and moe is not None | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |         and has_sm_8_0 | 
					
						
							| 
									
										
										
										
											2024-10-08 09:56:41 +00:00
										 |  |  |         and quantize in {"awq", "gptq"} | 
					
						
							|  |  |  |         and quant_method in {"awq", "gptq"} | 
					
						
							|  |  |  |         # We only support asymmetric quantization for AWQ. | 
					
						
							|  |  |  |         and (sym or quant_method == "awq") | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
@dataclass
class GPTQMarlinMoEWeight:
    """
    GPTQ-Marlin tensors of one projection for all experts, stacked.

    Every tensor field has a leading expert axis: ``field[i]`` holds the
    corresponding tensor of expert ``i`` (filled by ``_pack_weight``).
    """

    # Quantized (packed) weights, one slice per expert.
    qweight: torch.Tensor
    # Quantization zero points, one slice per expert; may be empty
    # (numel() == 0), in which case no zero points are passed to the kernel.
    qzeros: torch.Tensor
    # Quantization scales, one slice per expert.
    scales: torch.Tensor
    # act_order group indices, one slice per expert.
    g_idx: torch.Tensor
    # act_order input permutation (sort indices), one slice per expert.
    perm: torch.Tensor
    # Single flag for the whole stack, copied from the first packed expert.
    is_full_k: bool
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GPTQMarlinSparseMoELayer(nn.Module): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     MoE layer that uses a fused GPTQ-Marlin kernel. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         *, | 
					
						
							|  |  |  |         n_expert_group: Optional[int], | 
					
						
							|  |  |  |         n_experts: int, | 
					
						
							|  |  |  |         prefix: str, | 
					
						
							|  |  |  |         renormalize: bool, | 
					
						
							|  |  |  |         topk: int, | 
					
						
							|  |  |  |         topk_group: Optional[int], | 
					
						
							|  |  |  |         weights: Weights, | 
					
						
							|  |  |  |         gate_proj_name: str = "gate_proj", | 
					
						
							|  |  |  |         up_proj_name: str = "up_proj", | 
					
						
							|  |  |  |         down_proj_name: str = "down_proj", | 
					
						
							| 
									
										
										
										
											2025-01-30 15:40:25 +00:00
										 |  |  |         scoring_func: Optional[str] = None, | 
					
						
							|  |  |  |         e_score_correction_bias: Optional[float] = None, | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2025-02-14 10:33:49 +00:00
										 |  |  |         assert scoring_func in ( | 
					
						
							|  |  |  |             "sigmoid", | 
					
						
							|  |  |  |             "softmax", | 
					
						
							|  |  |  |         ), f"scoring func {scoring_func} is not handled" | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not ( | 
					
						
							| 
									
										
										
										
											2024-10-08 09:56:41 +00:00
										 |  |  |             isinstance(weights.loader, GPTQMarlinWeightsLoader) | 
					
						
							|  |  |  |             and can_use_marlin_moe_gemm( | 
					
						
							|  |  |  |                 quant_method=weights.loader.quant_method, | 
					
						
							|  |  |  |                 quantize=weights.loader.quantize, | 
					
						
							|  |  |  |                 sym=weights.loader.sym, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |         ): | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							| 
									
										
										
										
											2024-10-08 09:56:41 +00:00
										 |  |  |                 f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported" | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert (n_expert_group is None) == ( | 
					
						
							|  |  |  |             topk_group is None | 
					
						
							|  |  |  |         ), "n_expert_group and topk_group must both be None or have some value" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.n_expert_group = n_expert_group | 
					
						
							|  |  |  |         self.topk = topk | 
					
						
							|  |  |  |         self.topk_group = topk_group | 
					
						
							|  |  |  |         self.renormalize = renormalize | 
					
						
							| 
									
										
										
										
											2025-02-14 10:33:49 +00:00
										 |  |  |         self.scoring_func = scoring_func | 
					
						
							|  |  |  |         self.e_score_correction_bias = e_score_correction_bias | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         self.gate_up_proj = _load_expert_multi_weights_col( | 
					
						
							|  |  |  |             prefix=prefix, | 
					
						
							|  |  |  |             n_experts=n_experts, | 
					
						
							|  |  |  |             names=[gate_proj_name, up_proj_name], | 
					
						
							|  |  |  |             weights=weights, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.down_proj = _load_expert_weights_row( | 
					
						
							|  |  |  |             prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.bits = weights.loader.bits | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: | 
					
						
							|  |  |  |         return fused_marlin_moe( | 
					
						
							| 
									
										
										
										
											2024-10-08 09:56:41 +00:00
										 |  |  |             hidden_states=x, | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |             w1=self.gate_up_proj.qweight, | 
					
						
							|  |  |  |             w2=self.down_proj.qweight, | 
					
						
							|  |  |  |             w1_scale=self.gate_up_proj.scales, | 
					
						
							|  |  |  |             w2_scale=self.down_proj.scales, | 
					
						
							| 
									
										
										
										
											2024-10-08 09:56:41 +00:00
										 |  |  |             w1_zeros=( | 
					
						
							|  |  |  |                 self.gate_up_proj.qzeros | 
					
						
							|  |  |  |                 if self.gate_up_proj.qzeros.numel() > 0 | 
					
						
							|  |  |  |                 else None | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |             w2_zeros=( | 
					
						
							|  |  |  |                 self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |             g_idx1=self.gate_up_proj.g_idx, | 
					
						
							|  |  |  |             g_idx2=self.down_proj.g_idx, | 
					
						
							|  |  |  |             sort_indices1=self.gate_up_proj.perm, | 
					
						
							|  |  |  |             sort_indices2=self.down_proj.perm, | 
					
						
							|  |  |  |             is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k, | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |             gating_output=gating_output, | 
					
						
							|  |  |  |             topk=self.topk, | 
					
						
							|  |  |  |             renormalize=self.renormalize, | 
					
						
							|  |  |  |             use_grouped_topk=self.n_expert_group is not None, | 
					
						
							|  |  |  |             num_expert_group=self.n_expert_group, | 
					
						
							|  |  |  |             topk_group=self.topk_group, | 
					
						
							|  |  |  |             num_bits=self.bits, | 
					
						
							| 
									
										
										
										
											2025-02-14 10:33:49 +00:00
										 |  |  |             scoring_func=self.scoring_func, | 
					
						
							|  |  |  |             e_score_correction_bias=self.e_score_correction_bias, | 
					
						
							| 
									
										
										
										
											2024-09-30 09:14:32 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _load_expert_multi_weights_col( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     prefix: str, | 
					
						
							|  |  |  |     n_experts: int, | 
					
						
							|  |  |  |     names: List[str], | 
					
						
							|  |  |  |     weights: Weights, | 
					
						
							|  |  |  | ) -> GPTQMarlinMoEWeight: | 
					
						
							|  |  |  |     moe_weight = None | 
					
						
							|  |  |  |     for i in range(n_experts): | 
					
						
							|  |  |  |         weight = weights.get_multi_weights_col( | 
					
						
							|  |  |  |             [f"{prefix}.{i}.{name}" for name in names], 0 | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         assert isinstance(weight, GPTQMarlinWeight) | 
					
						
							|  |  |  |         moe_weight = _pack_weight( | 
					
						
							|  |  |  |             n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     assert moe_weight is not None | 
					
						
							|  |  |  |     return moe_weight | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _load_expert_weights_row( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     prefix: str, | 
					
						
							|  |  |  |     n_experts: int, | 
					
						
							|  |  |  |     name: str, | 
					
						
							|  |  |  |     weights: Weights, | 
					
						
							|  |  |  | ) -> GPTQMarlinMoEWeight: | 
					
						
							|  |  |  |     moe_weight = None | 
					
						
							|  |  |  |     for i in range(n_experts): | 
					
						
							|  |  |  |         weight = weights.get_weights_row( | 
					
						
							|  |  |  |             f"{prefix}.{i}.{name}", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         assert isinstance(weight, GPTQMarlinWeight) | 
					
						
							|  |  |  |         moe_weight = _pack_weight( | 
					
						
							|  |  |  |             n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     assert moe_weight is not None | 
					
						
							|  |  |  |     return moe_weight | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _pack_weight( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     n_experts: int, | 
					
						
							|  |  |  |     expert: int, | 
					
						
							|  |  |  |     moe_weight: Optional[GPTQMarlinMoEWeight], | 
					
						
							|  |  |  |     weight: GPTQMarlinWeight, | 
					
						
							|  |  |  | ) -> GPTQMarlinMoEWeight: | 
					
						
							|  |  |  |     if moe_weight is None: | 
					
						
							|  |  |  |         qweight = torch.empty( | 
					
						
							|  |  |  |             (n_experts,) + weight.qweight.shape, | 
					
						
							|  |  |  |             dtype=weight.qweight.dtype, | 
					
						
							|  |  |  |             device=weight.qweight.device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         qzeros = torch.empty( | 
					
						
							|  |  |  |             (n_experts,) + weight.qzeros.shape, | 
					
						
							|  |  |  |             dtype=weight.qzeros.dtype, | 
					
						
							|  |  |  |             device=weight.qzeros.device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         scales = torch.empty( | 
					
						
							|  |  |  |             (n_experts,) + weight.scales.shape, | 
					
						
							|  |  |  |             dtype=weight.scales.dtype, | 
					
						
							|  |  |  |             device=weight.scales.device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         g_idx = torch.empty( | 
					
						
							|  |  |  |             (n_experts,) + weight.g_idx.shape, | 
					
						
							|  |  |  |             dtype=weight.g_idx.dtype, | 
					
						
							|  |  |  |             device=weight.g_idx.device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         perm = torch.empty( | 
					
						
							|  |  |  |             (n_experts,) + weight.perm.shape, | 
					
						
							|  |  |  |             dtype=weight.perm.dtype, | 
					
						
							|  |  |  |             device=weight.perm.device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         moe_weight = GPTQMarlinMoEWeight( | 
					
						
							|  |  |  |             qweight=qweight, | 
					
						
							|  |  |  |             qzeros=qzeros, | 
					
						
							|  |  |  |             scales=scales, | 
					
						
							|  |  |  |             g_idx=g_idx, | 
					
						
							|  |  |  |             perm=perm, | 
					
						
							|  |  |  |             is_full_k=weight.is_full_k, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     moe_weight.qweight[expert] = weight.qweight | 
					
						
							|  |  |  |     moe_weight.qzeros[expert] = weight.qzeros | 
					
						
							|  |  |  |     moe_weight.scales[expert] = weight.scales | 
					
						
							|  |  |  |     moe_weight.g_idx[expert] = weight.g_idx | 
					
						
							|  |  |  |     moe_weight.perm[expert] = weight.perm | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return moe_weight | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def fused_marlin_moe( | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     hidden_states: torch.Tensor, | 
					
						
							|  |  |  |     w1: torch.Tensor, | 
					
						
							|  |  |  |     w2: torch.Tensor, | 
					
						
							|  |  |  |     w1_scale: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     w2_scale: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     gating_output: torch.Tensor, | 
					
						
							|  |  |  |     g_idx1: torch.Tensor, | 
					
						
							|  |  |  |     g_idx2: torch.Tensor, | 
					
						
							|  |  |  |     sort_indices1: torch.Tensor, | 
					
						
							|  |  |  |     sort_indices2: torch.Tensor, | 
					
						
							|  |  |  |     w1_zeros: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     w2_zeros: Optional[torch.Tensor] = None, | 
					
						
							|  |  |  |     is_k_full: bool, | 
					
						
							|  |  |  |     topk: int, | 
					
						
							|  |  |  |     renormalize: bool, | 
					
						
							|  |  |  |     num_bits: int = 8, | 
					
						
							|  |  |  |     use_grouped_topk: bool = False, | 
					
						
							|  |  |  |     num_expert_group: Optional[int] = None, | 
					
						
							|  |  |  |     custom_routing_function: Optional[Callable] = None, | 
					
						
							|  |  |  |     topk_group: Optional[int] = None, | 
					
						
							| 
									
										
										
										
											2025-02-14 10:33:49 +00:00
										 |  |  |     scoring_func: Optional[str] = None, | 
					
						
							|  |  |  |     e_score_correction_bias: Optional[float] = None, | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  | ) -> torch.Tensor: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     This function computes a Mixture of Experts (MoE) layer using two sets of | 
					
						
							|  |  |  |     weights, w1 and w2, and top-k gating mechanism. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Parameters: | 
					
						
							|  |  |  |     - hidden_states (torch.Tensor): The input tensor to the MoE layer. | 
					
						
							|  |  |  |     - w1 (torch.Tensor): The first set of expert weights. | 
					
						
							|  |  |  |     - w2 (torch.Tensor): The second set of expert weights. | 
					
						
							|  |  |  |     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for | 
					
						
							|  |  |  |         w1. | 
					
						
							|  |  |  |     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for | 
					
						
							|  |  |  |         w2. | 
					
						
							|  |  |  |     - gating_output (torch.Tensor): The output of the gating operation | 
					
						
							|  |  |  |         (before softmax). | 
					
						
							|  |  |  |     - g_idx1 (torch.Tensor): The first set of act_order indices. | 
					
						
							|  |  |  |     - g_idx2 (torch.Tensor): The second set of act_order indices. | 
					
						
							|  |  |  |     - sort_indices1 (torch.Tensor): The first act_order input permutation. | 
					
						
							|  |  |  |     - sort_indices2 (torch.Tensor): The second act_order input permutation. | 
					
						
							|  |  |  |     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. | 
					
						
							|  |  |  |     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. | 
					
						
							|  |  |  |     - renormalize (bool): If True, renormalize the top-k weights to sum to 1. | 
					
						
							|  |  |  |     - num_bits (bool): The number of bits in expert weights quantization. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Returns: | 
					
						
							|  |  |  |     - torch.Tensor: The output tensor after applying the MoE layer. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     # Check constraints. | 
					
						
							|  |  |  |     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" | 
					
						
							|  |  |  |     assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" | 
					
						
							|  |  |  |     assert hidden_states.shape[1] == w2.shape[2] // ( | 
					
						
							|  |  |  |         num_bits // 2 | 
					
						
							|  |  |  |     ), "Hidden size mismatch w2" | 
					
						
							|  |  |  |     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" | 
					
						
							|  |  |  |     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" | 
					
						
							|  |  |  |     assert w1.is_contiguous(), "Expert weights1 must be contiguous" | 
					
						
							|  |  |  |     assert w2.is_contiguous(), "Expert weights2 must be contiguous" | 
					
						
							|  |  |  |     assert hidden_states.dtype == torch.float16 | 
					
						
							|  |  |  |     assert num_bits in [4, 8] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # DeekSeekv2 uses grouped_top_k | 
					
						
							|  |  |  |     if use_grouped_topk: | 
					
						
							|  |  |  |         assert topk_group is not None | 
					
						
							|  |  |  |         assert num_expert_group is not None | 
					
						
							|  |  |  |         topk_weights, topk_ids = moe_kernels.grouped_topk( | 
					
						
							|  |  |  |             hidden_states=hidden_states, | 
					
						
							|  |  |  |             gating_output=gating_output, | 
					
						
							|  |  |  |             topk=topk, | 
					
						
							|  |  |  |             renormalize=renormalize, | 
					
						
							|  |  |  |             num_expert_group=num_expert_group, | 
					
						
							|  |  |  |             topk_group=topk_group, | 
					
						
							| 
									
										
										
										
											2025-02-14 10:33:49 +00:00
										 |  |  |             scoring_func=scoring_func, | 
					
						
							|  |  |  |             e_score_correction_bias=e_score_correction_bias, | 
					
						
							| 
									
										
										
										
											2025-02-10 18:19:25 +00:00
										 |  |  |         ) | 
					
						
							|  |  |  |     elif custom_routing_function is None: | 
					
						
							|  |  |  |         topk_weights, topk_ids = moe_kernels.fused_topk( | 
					
						
							|  |  |  |             hidden_states=hidden_states, | 
					
						
							|  |  |  |             gating_output=gating_output, | 
					
						
							|  |  |  |             topk=topk, | 
					
						
							|  |  |  |             renormalize=renormalize, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         topk_weights, topk_ids = custom_routing_function( | 
					
						
							|  |  |  |             hidden_states=hidden_states, | 
					
						
							|  |  |  |             gating_output=gating_output, | 
					
						
							|  |  |  |             topk=topk, | 
					
						
							|  |  |  |             renormalize=renormalize, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     return moe_kernels.fused_marlin_moe( | 
					
						
							|  |  |  |         hidden_states=hidden_states, | 
					
						
							|  |  |  |         w1=w1, | 
					
						
							|  |  |  |         w2=w2, | 
					
						
							|  |  |  |         w1_scale=w1_scale, | 
					
						
							|  |  |  |         w2_scale=w2_scale, | 
					
						
							|  |  |  |         gating_output=gating_output, | 
					
						
							|  |  |  |         topk_weights=topk_weights, | 
					
						
							|  |  |  |         topk_ids=topk_ids, | 
					
						
							|  |  |  |         g_idx1=g_idx1, | 
					
						
							|  |  |  |         g_idx2=g_idx2, | 
					
						
							|  |  |  |         sort_indices1=sort_indices1, | 
					
						
							|  |  |  |         sort_indices2=sort_indices2, | 
					
						
							|  |  |  |         w1_zeros=w1_zeros, | 
					
						
							|  |  |  |         w2_zeros=w2_zeros, | 
					
						
							|  |  |  |         num_bits=num_bits, | 
					
						
							|  |  |  |         is_k_full=is_k_full, | 
					
						
							|  |  |  |     ) |