Support latest moe kernels

Daniël de Kok 2025-02-04 11:23:25 +00:00
parent d39f896c5c
commit c1a564e738
3 changed files with 110 additions and 7 deletions


@@ -12,7 +12,7 @@ from text_generation_server.layers.fp8 import (
 )
 
 try:
-    from moe_kernels.fused_moe import fused_moe
+    from .unquantized import fused_moe
 except Exception:
     fused_moe = None
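
The import now resolves from the local unquantized module instead of the moe_kernels package, keeping the same fallback to None when the kernels cannot be imported. A minimal sketch of how a caller honors that contract (the helper name below is hypothetical, not part of this commit):

# Hypothetical guard illustrating the fallback contract: `fused_moe`
# is None when the optional kernel import above failed.
def require_fused_moe():
    if fused_moe is None:
        raise ImportError("MoE kernels are not available in this environment")
    return fused_moe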


@@ -252,7 +252,6 @@ def fused_marlin_moe(
     topk: int,
     renormalize: bool,
     num_bits: int = 8,
-    override_config: Optional[Dict[str, Any]] = None,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
@@ -279,8 +278,6 @@ def fused_marlin_moe(
     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_bits (int): The number of bits in expert weights quantization.
 
     Returns:
@@ -340,7 +337,6 @@ def fused_marlin_moe(
         sort_indices2=sort_indices2,
         w1_zeros=w1_zeros,
         w2_zeros=w2_zeros,
-        override_config=override_config,
         num_bits=num_bits,
         is_k_full=is_k_full,
     )
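
With override_config removed from the signature, callers that still pass it will fail with a TypeError; the kernel configuration is presumably selected internally now. An illustrative migration shim (not part of this commit) for call sites written against the old signature:

# Illustrative shim: strip the removed keyword before forwarding,
# so legacy call sites keep working against the new API.
def call_fused_marlin_moe(*args, **kwargs):
    kwargs.pop("override_config", None)  # no longer accepted upstream
    return fused_marlin_moe(*args, **kwargs)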


@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -86,7 +86,7 @@ class UnquantizedSparseMoELayer(nn.Module):
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
         )
-        return moe_kernels.fused_moe(
+        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
@@ -159,3 +159,110 @@ def _load_expert_weights_row(
     assert all_weight is not None
 
     return all_weight
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and a top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - use_grouped_topk (bool): If True, use grouped_topk instead of fused_topk.
+        Note: the DeepSeek-V2 model uses grouped_topk.
+    - num_expert_group (Optional[int]): additional parameter for grouped_topk.
+    - topk_group (Optional[int]): additional parameter for grouped_topk.
+    - custom_routing_function (Optional[Callable]): Optional routing function
+        used in place of fused_topk/grouped_topk.
+    - scoring_func (str): Scoring function applied to the gating output
+        during grouped top-k routing. Defaults to "softmax".
+    - e_score_correction_bias (Optional[torch.Tensor]): Optional correction
+        bias applied to expert scores during grouped top-k routing.
+    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
+    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
+    - block_shape (Optional[List[int]]): Optional block size for block-wise
+        quantization.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = moe_kernels.grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = moe_kernels.fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return moe_kernels.fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
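
For reference, a minimal usage sketch of the new wrapper, assuming the moe_kernels backend is installed, a CUDA device is available, and the usual fused layout where w1 packs the gate and up projections as [E, 2*I, H] and w2 is the down projection as [E, H, I]; all shapes are illustrative:

import torch

# T tokens, H hidden size, I intermediate size, E experts (illustrative).
T, H, I, E = 4, 128, 256, 8
hidden_states = torch.randn(T, H, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")  # gate + up
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")      # down
gating_output = torch.randn(T, E, dtype=torch.float16, device="cuda")

# Route each token to its top-2 experts, renormalizing the top-k weights.
out = fused_moe(
    hidden_states,
    w1=w1,
    w2=w2,
    gating_output=gating_output,
    topk=2,
    renormalize=True,
)
assert out.shape == hidden_states.shape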