Support sigmoid scoring function in GPTQ-MoE (#3017)

commit 6df0fc0b55
parent d6881c37ab
Author: Daniël de Kok (committed by GitHub)
Date:   2025-02-14 11:33:49 +01:00

@@ -74,8 +74,10 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         scoring_func: Optional[str] = None,
         e_score_correction_bias: Optional[float] = None,
     ):
-        assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
-        assert e_score_correction_bias is None, "scoring correction bias is not handled"
+        assert scoring_func in (
+            "sigmoid",
+            "softmax",
+        ), f"scoring func {scoring_func} is not handled"
         super().__init__()

         if not (
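
The relaxed assertion is the heart of this commit: softmax scoring normalizes the router logits jointly across experts, while sigmoid scoring (used by DeepSeek-V3-style routers) maps each expert's logit to (0, 1) independently, which is also why a per-expert score correction bias becomes meaningful and the second assertion is dropped. A minimal illustration of the two scoring functions in plain PyTorch, with shapes chosen for illustration rather than taken from this file:

import torch

logits = torch.randn(4, 8)  # [tokens, experts], illustrative shapes

softmax_scores = torch.softmax(logits, dim=-1)  # each row sums to 1
sigmoid_scores = logits.sigmoid()               # independent per-expert scores
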
@@ -98,6 +100,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         self.topk = topk
         self.topk_group = topk_group
         self.renormalize = renormalize
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias

         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -139,6 +143,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
             num_bits=self.bits,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
         )
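
With the scoring function and bias stored on the layer and forwarded into the fused kernel call above, a sigmoid-routed model can construct the layer directly. A hypothetical instantiation; prefix, n_experts, and weights are placeholders, since only the routing-related parameters appear in this diff:

moe = GPTQMarlinSparseMoELayer(
    prefix="model.layers.0.mlp.experts",  # hypothetical checkpoint prefix
    n_experts=64,                         # hypothetical expert count
    n_expert_group=8,
    renormalize=True,
    topk=6,
    topk_group=2,
    weights=weights,                      # hypothetical loaded-weights object
    scoring_func="sigmoid",               # accepted now; previously only "softmax"
    e_score_correction_bias=0.0,          # no longer rejected by an assertion
)
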
@@ -256,6 +262,8 @@ def fused_marlin_moe(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     topk_group: Optional[int] = None,
+    scoring_func: Optional[str] = None,
+    e_score_correction_bias: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
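
The standalone fused_marlin_moe entry point gains the same two keyword arguments. A hypothetical call; the tensor arguments and their shapes are assumptions, since only the routing keywords are visible in this hunk:

out = fused_marlin_moe(
    hidden_states,            # assumed: [num_tokens, hidden_size] activations
    w1, w2,                   # assumed: Marlin-packed gate/up and down weights
    gating_output,            # assumed: [num_tokens, n_experts] router logits
    topk=6,
    renormalize=True,
    num_expert_group=8,
    topk_group=2,
    scoring_func="sigmoid",
    e_score_correction_bias=0.0,
)
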
@@ -307,6 +315,8 @@ def fused_marlin_moe(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
         )
     elif custom_routing_function is None:
         topk_weights, topk_ids = moe_kernels.fused_topk(
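
Only the grouped-routing branch above consumes the new arguments; the fused_topk fallback beneath it is untouched. For reference, this is roughly what grouped top-k with sigmoid scoring and a selection-only correction bias computes, written as a self-contained PyTorch sketch rather than the moe_kernels implementation:

import torch

def grouped_topk_sketch(gating_output, topk, num_expert_group, topk_group,
                        renormalize=True, scoring_func="softmax",
                        e_score_correction_bias=None):
    # Score all experts with the configured scoring function.
    if scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        scores = torch.softmax(gating_output, dim=-1)
    # The bias only influences which experts are *selected*.
    selection = scores if e_score_correction_bias is None else scores + e_score_correction_bias
    num_tokens, num_experts = selection.shape
    grouped = selection.view(num_tokens, num_expert_group, -1)
    group_scores = grouped.max(dim=-1).values                  # best expert per group
    group_ids = group_scores.topk(topk_group, dim=-1).indices  # keep the top groups
    mask = torch.zeros_like(group_scores).scatter_(1, group_ids, 1.0)
    mask = mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = selection.masked_fill(mask == 0, float("-inf"))
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)  # weights come from the unbiased scores
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
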