From 6df0fc0b55cf9c8e8124e7a3e3d8ed96a9aeb5f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Fri, 14 Feb 2025 11:33:49 +0100
Subject: [PATCH] Support sigmoid scoring function in GPTQ-MoE (#3017)

---
 .../layers/moe/gptq_marlin.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py
index 75c076ab2..d1ce4f3e0 100644
--- a/server/text_generation_server/layers/moe/gptq_marlin.py
+++ b/server/text_generation_server/layers/moe/gptq_marlin.py
@@ -74,8 +74,10 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         scoring_func: Optional[str] = None,
         e_score_correction_bias: Optional[float] = None,
     ):
-        assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
-        assert e_score_correction_bias is None, "scoring correction bias is not handled"
+        assert scoring_func in (
+            "sigmoid",
+            "softmax",
+        ), f"scoring func {scoring_func} is not handled"
         super().__init__()
 
         if not (
@@ -98,6 +100,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         self.topk = topk
         self.topk_group = topk_group
         self.renormalize = renormalize
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias
 
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -139,6 +143,8 @@
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
             num_bits=self.bits,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
         )
 
 
@@ -256,6 +262,8 @@ def fused_marlin_moe(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     topk_group: Optional[int] = None,
+    scoring_func: Optional[str] = None,
+    e_score_correction_bias: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -307,6 +315,8 @@
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
         )
     elif custom_routing_function is None:
         topk_weights, topk_ids = moe_kernels.fused_topk(
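
Note (appended for context, not part of the patch): the sketch below illustrates what the two arguments forwarded by this patch typically mean in top-k MoE routing. It is a minimal standalone sketch assuming DeepSeek-V3-style semantics, where sigmoid scores each expert independently and the correction bias affects only which experts are selected, not the combine weights. The helper name select_experts is hypothetical; in TGI the actual selection happens inside moe_kernels.grouped_topk, and the bias is shown here as a per-expert tensor or scalar even though the patched signature types it as Optional[float].

    import torch

    def select_experts(
        router_logits: torch.Tensor,  # [num_tokens, num_experts]
        topk: int,
        renormalize: bool = True,
        scoring_func: str = "softmax",
        e_score_correction_bias: torch.Tensor | float | None = None,
    ):
        # Score experts either jointly (softmax) or independently (sigmoid).
        if scoring_func == "softmax":
            scores = torch.softmax(router_logits, dim=-1)
        elif scoring_func == "sigmoid":
            scores = router_logits.sigmoid()
        else:
            raise ValueError(f"scoring func {scoring_func} is not handled")

        # Assumption (DeepSeek-V3 convention): the bias shifts only the
        # *selection* scores; combine weights come from the unbiased scores.
        if e_score_correction_bias is not None:
            selection_scores = scores + e_score_correction_bias
        else:
            selection_scores = scores

        _, topk_ids = torch.topk(selection_scores, k=topk, dim=-1)
        topk_weights = scores.gather(-1, topk_ids)

        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

    # Example: route 4 tokens over 8 experts with sigmoid scoring.
    weights, ids = select_experts(torch.randn(4, 8), topk=2, scoring_func="sigmoid")

With scoring_func="sigmoid" each expert is scored independently, so the scores of the selected experts do not naturally sum to one; the renormalize step is what turns them into combine weights, whereas softmax scores already form a distribution over all experts.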