Support sigmoid scoring function in GPTQ-MoE (#3017)

Daniël de Kok, 2025-02-14 11:33:49 +01:00 (committed by GitHub)
parent d6881c37ab
commit 6df0fc0b55

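Summary of the change, as read from the diff below: GPTQMarlinSparseMoELayer previously rejected any scoring function other than softmax and asserted that no score correction bias was given. The assertion is widened to also accept "sigmoid", the layer now stores scoring_func and e_score_correction_bias, and both are threaded through to fused_marlin_moe so they reach the routing step.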

@@ -74,8 +74,10 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         scoring_func: Optional[str] = None,
         e_score_correction_bias: Optional[float] = None,
     ):
-        assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
-        assert e_score_correction_bias is None, "scoring correction bias is not handled"
+        assert scoring_func in (
+            "sigmoid",
+            "softmax",
+        ), f"scoring func {scoring_func} is not handled"
         super().__init__()
 
         if not (
@@ -98,6 +100,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
         self.topk = topk
         self.topk_group = topk_group
         self.renormalize = renormalize
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias
 
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -139,6 +143,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
             num_bits=self.bits,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
         )
@@ -256,6 +262,8 @@ def fused_marlin_moe(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     topk_group: Optional[int] = None,
+    scoring_func: Optional[str] = None,
+    e_score_correction_bias: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -307,6 +315,8 @@ def fused_marlin_moe(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
         )
     elif custom_routing_function is None:
         topk_weights, topk_ids = moe_kernels.fused_topk(
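For context, a minimal sketch of what the two scoring functions and the correction bias mean for top-k expert routing. This is not code from the repository: route_tokens is a hypothetical reference helper, the bias is modeled as a per-expert tensor even though the signature above types it Optional[float], and the bias semantics shown (the bias shifts which experts are selected, while the returned weights come from the unbiased scores, as in DeepSeek-V3-style routers) are an assumption about how the kernel consumes e_score_correction_bias.

from typing import Optional, Tuple

import torch


def route_tokens(
    gate_logits: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,  # [num_experts]
    renormalize: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical reference router, for illustration only."""
    if scoring_func == "softmax":
        # Scores compete across experts and sum to 1 per token.
        scores = gate_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        # Each expert is scored independently in (0, 1); scores no longer sum to 1.
        scores = gate_logits.sigmoid()
    else:
        raise ValueError(f"scoring func {scoring_func} is not handled")

    # Assumption: the bias only influences *which* experts are picked;
    # the returned weights are gathered from the unbiased scores.
    selection_scores = scores
    if e_score_correction_bias is not None:
        selection_scores = scores + e_score_correction_bias

    _, topk_ids = selection_scores.topk(topk, dim=-1)
    topk_weights = scores.gather(-1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Example: route 4 tokens over 8 experts, picking 2 experts per token.
weights, ids = route_tokens(torch.randn(4, 8), topk=2, scoring_func="sigmoid")

The practical difference: softmax couples all expert scores for a token, while sigmoid scores each expert independently, which is the convention used by routers that also supply a per-expert correction bias for load balancing; renormalize then rescales the selected weights to sum to 1.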