From c1a564e73810fc0821e2d92bcd525541d2e0d776 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 4 Feb 2025 11:23:25 +0000
Subject: [PATCH] Support latest moe kernels

---
 .../text_generation_server/layers/moe/fp8.py |   2 +-
 .../layers/moe/gptq_marlin.py                |   4 -
 .../layers/moe/unquantized.py                | 111 +++++++++++++++++-
 3 files changed, 110 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/layers/moe/fp8.py b/server/text_generation_server/layers/moe/fp8.py
index 3016c8a2..071b2abe 100644
--- a/server/text_generation_server/layers/moe/fp8.py
+++ b/server/text_generation_server/layers/moe/fp8.py
@@ -12,7 +12,7 @@ from text_generation_server.layers.fp8 import (
 )
 
 try:
-    from moe_kernels.fused_moe import fused_moe
+    from .unquantized import fused_moe
 except Exception:
     fused_moe = None
 
diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py
index 0819c2f5..cb604462 100644
--- a/server/text_generation_server/layers/moe/gptq_marlin.py
+++ b/server/text_generation_server/layers/moe/gptq_marlin.py
@@ -252,7 +252,6 @@ def fused_marlin_moe(
     topk: int,
     renormalize: bool,
     num_bits: int = 8,
-    override_config: Optional[Dict[str, Any]] = None,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
@@ -279,8 +278,6 @@ def fused_marlin_moe(
     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-      for the kernel configuration.
     - num_bits (bool): The number of bits in expert weights quantization.
 
     Returns:
@@ -340,7 +337,6 @@ def fused_marlin_moe(
         sort_indices2=sort_indices2,
         w1_zeros=w1_zeros,
         w2_zeros=w2_zeros,
-        override_config=override_config,
         num_bits=num_bits,
         is_k_full=is_k_full,
     )
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index bdef06c6..92bd5e6d 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -86,7 +86,7 @@ class UnquantizedSparseMoELayer(nn.Module):
                 num_expert_group=self.n_expert_group,
                 topk_group=self.topk_group,
             )
-        return moe_kernels.fused_moe(
+        return fused_moe(
             x,
             w1=self.gate_up_proj,
             w2=self.down_proj,
@@ -159,3 +159,110 @@ def _load_expert_weights_row(
     assert all_weight is not None
 
     return all_weight
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and a top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+      (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+      Defaults to False.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+      note: Deepseekv2 model uses grouped_topk
+    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
+      products for w1 and w2. Defaults to False.
+    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
+      activation to compute the inner products for w1 and w2. Defaults to False.
+    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+      activation to compute the inner products for w1 and w2.
+      Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+      w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+      w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+      a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+      a2.
+    - block_shape (Optional[List[int]]): Optional block size for block-wise
+      quantization.
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        from loguru import logger
+        import inspect
+
+        logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
+        topk_weights, topk_ids = moe_kernels.grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = moe_kernels.fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return moe_kernels.fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
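
With this patch, `fused_moe` in `unquantized.py` is a thin local wrapper around `moe_kernels.fused_topk`/`moe_kernels.grouped_topk` and `moe_kernels.fused_experts`. Below is a minimal usage sketch of the unquantized path; it assumes the `moe-kernels` package and a CUDA device are available, the module path comes from the patch, and the tensor sizes are purely illustrative:

    import torch

    from text_generation_server.layers.moe.unquantized import fused_moe

    # Illustrative sizes: 4 tokens, hidden size 512, 8 experts with
    # intermediate size 1024, top-2 routing.
    x = torch.randn(4, 512, dtype=torch.float16, device="cuda")
    # w1 packs the gate and up projections per expert: [experts, 2 * intermediate, hidden].
    w1 = torch.randn(8, 2 * 1024, 512, dtype=torch.float16, device="cuda")
    # w2 is the down projection per expert: [experts, hidden, intermediate].
    w2 = torch.randn(8, 512, 1024, dtype=torch.float16, device="cuda")
    # Router logits per token: [tokens, experts].
    gating_output = torch.randn(4, 8, dtype=torch.float16, device="cuda")

    out = fused_moe(
        x,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        topk=2,
        renormalize=True,
    )
    print(out.shape)  # torch.Size([4, 512])

The `UnquantizedSparseMoELayer.forward` change in the patch makes the same call, with `self.gate_up_proj` and `self.down_proj` in place of the random weights above.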
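
The `use_grouped_topk` branch forwards `scoring_func` and `e_score_correction_bias` to `moe_kernels.grouped_topk`, which is what DeepSeek-style grouped routing uses. A sketch of that path, continuing from the snippet above (same imports); the sigmoid scoring function and the one-value-per-expert bias shape are assumptions about the kernel's routing, not something the patch itself specifies:

    # 16 experts split into 4 groups; keep the best 2 groups, top-4 experts overall.
    x = torch.randn(4, 512, dtype=torch.float16, device="cuda")
    w1 = torch.randn(16, 2 * 1024, 512, dtype=torch.float16, device="cuda")
    w2 = torch.randn(16, 512, 1024, dtype=torch.float16, device="cuda")
    gating_output = torch.randn(4, 16, dtype=torch.float16, device="cuda")
    # Assumed: one correction term per expert (DeepSeek-V3-style routing bias).
    bias = torch.zeros(16, dtype=torch.float32, device="cuda")

    out = fused_moe(
        x,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        topk=4,
        renormalize=True,
        use_grouped_topk=True,
        num_expert_group=4,
        topk_group=2,
        scoring_func="sigmoid",
        e_score_correction_bias=bias,
    )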