mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Support sigmoid scoring function in GPTQ-MoE (#3017)
This commit is contained in:
parent
d6881c37ab
commit
6df0fc0b55
@@ -74,8 +74,10 @@ class GPTQMarlinSparseMoELayer(nn.Module):
|
|||||||
scoring_func: Optional[str] = None,
|
scoring_func: Optional[str] = None,
|
||||||
e_score_correction_bias: Optional[float] = None,
|
e_score_correction_bias: Optional[float] = None,
|
||||||
):
|
):
|
||||||
assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
|
assert scoring_func in (
|
||||||
assert e_score_correction_bias is None, "scoring correction bias is not handled"
|
"sigmoid",
|
||||||
|
"softmax",
|
||||||
|
), f"scoring func {scoring_func} is not handled"
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
@@ -98,6 +100,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
|
|||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.renormalize = renormalize
|
self.renormalize = renormalize
|
||||||
|
self.scoring_func = scoring_func
|
||||||
|
self.e_score_correction_bias = e_score_correction_bias
|
||||||
|
|
||||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
self.gate_up_proj = _load_expert_multi_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
@@ -139,6 +143,8 @@ class GPTQMarlinSparseMoELayer(nn.Module):
|
|||||||
num_expert_group=self.n_expert_group,
|
num_expert_group=self.n_expert_group,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_bits=self.bits,
|
num_bits=self.bits,
|
||||||
|
scoring_func=self.scoring_func,
|
||||||
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -256,6 +262,8 @@ def fused_marlin_moe(
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
|
scoring_func: Optional[str] = None,
|
||||||
|
e_score_correction_bias: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||||
@@ -307,6 +315,8 @@ def fused_marlin_moe(
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
elif custom_routing_function is None:
|
elif custom_routing_function is None:
|
||||||
topk_weights, topk_ids = moe_kernels.fused_topk(
|
topk_weights, topk_ids = moe_kernels.fused_topk(
|
||||||
|
Loading…
Reference in New Issue
Block a user