From 3f14cd1420fd30287d34320fd68b2779794b3010 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 24 Sep 2024 14:27:06 +0200
Subject: [PATCH] Add `DenseMoELayer` and wire it up in Mixtral/Deepseek V2
 (#2537)

This replaces the custom layers in both models.
---
 .../layers/moe/__init__.py                    | 165 +++++++++++++++++-
 .../flash_deepseek_v2_modeling.py             | 127 +++-----------
 .../custom_modeling/flash_mixtral_modeling.py | 139 +++------------
 3 files changed, 211 insertions(+), 220 deletions(-)

diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index e32003aee..3171af902 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -1,15 +1,178 @@
-from typing import Optional
+from typing import Optional, Protocol, runtime_checkable
 
 import torch
 import torch.nn as nn
+from loguru import logger
+from transformers.activations import ACT2FN
+
+from text_generation_server.layers import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+)
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )
 
+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_topk, grouped_topk
+
+
+# NOTE: we are using a protocol here, because multiple inheritance is not nice.
+# We need `Module`, and `Module` -> some abstract class -> some concrete
+# class inheritance is whacky.
+
+
+@runtime_checkable
+class MoELayer(Protocol):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+        hidden_act: str = "silu",
+    ): ...
+
+    def forward(
+        self, x: torch.Tensor, *, gating_output: torch.Tensor
+    ) -> torch.Tensor: ...
+
+
+class DenseMoELayer(nn.Module):
+    """
+    Layer for MoE that applies *all* experts to each token and then weights
+    their outputs based on the calculated routing. This layer is much slower
+    than `SparseMoELayer` and should only be used when no fused kernels are
+    available (e.g. for unsupported quantizers).
+ """ + + def __init__( + self, + *, + n_expert_group: Optional[int], + n_experts: int, + prefix: str, + renormalize: bool, + topk: int, + topk_group: Optional[int], + weights: Weights, + gate_proj_name: str = "gate_proj", + up_proj_name: str = "up_proj", + down_proj_name: str = "down_proj", + hidden_act: str = "silu", + ): + super().__init__() + + log_once( + logger.info, + "No fused layers are available for this model type, using (slower) dense MoE layer", + ) + + assert (n_expert_group is None) == ( + topk_group is None + ), "n_expert_group and topk_group must both be None or have some value" + + self.n_expert_group = n_expert_group + self.n_experts = n_experts + self.renormalize = renormalize + self.topk = topk + self.topk_group = topk_group + + if "gelu" in hidden_act: + self.act = lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" + if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none" + ), + ) + elif "silu" in hidden_act: + self.act = torch.nn.functional.silu + else: + self.act = ACT2FN[hidden_act] + + self.gate_proj = [ + TensorParallelColumnLinear.load( + None, + prefix=f"{prefix}.{i}.{gate_proj_name}", + weights=weights, + bias=False, + ) + for i in range(self.n_experts) + ] + self.up_proj = [ + TensorParallelColumnLinear.load( + None, + prefix=f"{prefix}.{i}.{up_proj_name}", + weights=weights, + bias=False, + ) + for i in range(self.n_experts) + ] + self.down_proj = [ + TensorParallelRowLinear.load( + None, + prefix=f"{prefix}.{i}.{down_proj_name}", + weights=weights, + bias=False, + ) + for i in range(self.n_experts) + ] + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gating_output: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.n_expert_group is not None and self.topk_group is not None: + topk_weights, topk_ids = grouped_topk( + x, + gating_output, + self.topk, + renormalize=self.renormalize, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + else: + topk_weights, topk_ids = fused_topk( + x, gating_output, self.topk, self.renormalize + ) + topk_weights = topk_weights.to(x.dtype) + + weights = torch.zeros( + topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device + ) + + weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype)) + + out = torch.zeros_like(x) + for i in range(self.n_experts): + h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x) + h = self.down_proj[i](h, reduce=False) + out += h * weights[:, i].view(-1, 1) + + return out + class SparseMoELayer(nn.Module): """ diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 328f239b4..2ca7cc249 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -13,10 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
 
 import torch
 import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
 from text_generation_server.layers import (
     FastLinear,
     SpeculativeHead,
@@ -26,22 +30,16 @@ from text_generation_server.layers import (
     get_linear,
 )
 from text_generation_server.layers.attention import (
+    Seqlen,
     attention,
     paged_attention,
     reshape_and_cache,
-    Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
-from torch import nn
-from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
-
-if SYSTEM != "ipex":
-    from moe_kernels.fused_moe import grouped_topk
 
 if SYSTEM == "rocm":
     try:
@@ -410,8 +408,14 @@ class DeepseekV2MLP(nn.Module):
         )
 
 
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: DeepseekV2Config, weights):
+class DeepseekV2MoE(nn.Module):
+    def __init__(
+        self,
+        prefix,
+        config: DeepseekV2Config,
+        moe_layer_cls: Type[MoELayer],
+        weights,
+    ):
         super().__init__()
 
         self.hidden_dim = config.hidden_size
@@ -423,7 +427,7 @@ class BlockSparseMoE(nn.Module):
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        self.moe_layer = SparseMoELayer(
+        self.moe_layer = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.n_routed_experts,
             n_expert_group=config.n_group,
@@ -432,6 +436,7 @@ class BlockSparseMoE(nn.Module):
             topk_group=config.topk_group,
             weights=weights,
         )
+        assert isinstance(self.moe_layer, MoELayer)
 
         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
@@ -466,96 +471,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)
 
 
-class DenseMoE(nn.Module):
-    def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
-        super().__init__()
-
-        self.hidden_dim = config.hidden_size
-        self.moe_intermediate_size = config.moe_intermediate_size
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
-        self.routed_scaling_factor = config.routed_scaling_factor
-
-        # Gating
-        #
-        # Seems like no one quantizes the gate.
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.experts = [
-            DeepseekV2MLP(
-                f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
-            )
-            for i in range(self.n_routed_experts)
-        ]
-
-        if config.n_shared_experts is not None:
-            self.shared_experts = DeepseekV2MLP(
-                prefix=f"{prefix}.shared_experts",
-                config=config,
-                weights=weights,
-                intermediate_size=config.moe_intermediate_size
-                * config.n_shared_experts,
-            )
-        else:
-            self.shared_experts = None
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        if self.shared_experts is not None:
-            shared_output = self.shared_experts(x, reduce=False)
-        else:
-            shared_output = None
-
-        # gate_logits: (sequence_length, n_experts)
-        router_logits = self.gate(x)
-
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-
-        out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
-
-        if shared_output is not None:
-            out = out + shared_output
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-    def moe_infer_gpu(
-        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
-    ):
-        weights = torch.zeros(
-            topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
-        )
-        weights.scatter_(1, topk_ids, topk_weight)
-
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i, expert in enumerate(self.experts):
-            # Add expert output to out with masking
-            out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
-        return out
-
-
 class DeepseekV2Layer(nn.Module):
     def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
@@ -572,10 +487,12 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = (
-                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
             )
-            self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
+            self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
         else:
             self.mlp = DeepseekV2MLP(
                 prefix=f"{prefix}.mlp",
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 2fda718bb..02da13848 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -18,38 +18,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Optional, Tuple, Type
+
 import torch
 import torch.distributed
-
-
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
-
-from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional, List, Tuple
 
-from text_generation_server.layers.attention import (
-    paged_attention,
-    attention,
-    reshape_and_cache,
-    Seqlen,
-)
 from text_generation_server.layers import (
     FastLinear,
-    TensorParallelRowLinear,
+    SpeculativeHead,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    SpeculativeHead,
+    TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.moe import SparseMoELayer
-from text_generation_server.layers.layernorm import (
-    FastRMSNorm,
-)
-from text_generation_server.layers.rotary import (
-    PositionRotaryEmbedding,
+from text_generation_server.layers.attention import (
+    Seqlen,
+    attention,
+    paged_attention,
+    reshape_and_cache,
 )
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
@@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
     return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
 
 
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
+class MixtralMoE(nn.Module):
+    def __init__(
+        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
+    ):
         super().__init__()
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             n_expert_group=None,
             n_experts=config.num_local_experts,
             prefix=f"{prefix}.experts",
@@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
             up_proj_name="w3",
             down_proj_name="w2",
         )
+        assert isinstance(self.moe, MoELayer)
 
         self.process_group = weights.process_group
 
@@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
         return out.view(*x.shape)
 
 
-class DenseMoE(nn.Module):
-    def __init__(self, prefix, config: MixtralConfig, weights):
-        super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
-        # gating
-        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-
-        self.w1 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w3 = [
-            TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-        self.w2 = [
-            TensorParallelRowLinear.load(
-                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
-            )
-            for i in range(self.num_experts)
-        ]
-
-        self.process_group = weights.process_group
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
-        # Final output tensor
-        out = x.new_zeros(x.shape[0], self.hidden_dim)
-        for i in range(self.num_experts):
-            h = self.act(self.w1[i](x)) * self.w3[i](x)
-            h = self.w2[i](h, reduce=False)
-            # Add expert output to out with masking
-            out += h * weights[:, i].view(-1, 1)
-
-        # Reduce sum
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
-
-
 class MixtralLayer(nn.Module):
     def __init__(self, prefix: str, layer_id, config, weights):
         super().__init__()
@@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
 
-        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
-        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
+        moe_layer_cls = (
+            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+        )
+        self.moe = MixtralMoE(
+            f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+        )
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
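
Usage sketch (not part of the diff above): a minimal example of how the `MoELayer`
protocol introduced in this patch is meant to be consumed, mirroring the
`MixtralLayer`/`DeepseekV2Layer` wiring. `config`, `weights`, `prefix`, `x`, and
`router_logits` are assumed to be provided by the surrounding model code, and the
routing settings shown are illustrative rather than taken verbatim from the patch.

    from text_generation_server.layers.moe import (
        DenseMoELayer,
        MoELayer,
        SparseMoELayer,
    )

    # Use the fused sparse kernels when the weight loader supports them,
    # otherwise fall back to the (slower) dense implementation.
    moe_layer_cls = (
        SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
    )

    moe = moe_layer_cls(
        prefix=f"{prefix}.experts",          # assumed checkpoint prefix
        n_experts=config.num_local_experts,  # Mixtral-style expert count
        n_expert_group=None,                 # grouped routing is only used by Deepseek V2
        topk=config.num_experts_per_tok,
        topk_group=None,
        renormalize=True,                    # illustrative; model-specific in practice
        weights=weights,
    )
    # Both classes satisfy the runtime-checkable protocol.
    assert isinstance(moe, MoELayer)

    # Both implementations expose the same forward interface: route tokens
    # with the gate logits and return the weighted expert outputs.
    out = moe(x, gating_output=router_logits)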