From 0cafbf3b54c74a25d0e36b3e30ac09c527164084 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 25 Sep 2024 08:47:52 +0000
Subject: [PATCH] Add support for dense MoE

---
 .../custom_modeling/flash_llama_modeling.py | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 140e1c23..634b09f7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -19,7 +19,7 @@
 # limitations under the License.
 
 from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
 
 import torch
 import torch.distributed
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -253,14 +253,16 @@ class FlashLlamaAttention(torch.nn.Module):
         )
 
 
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
+class Phi3MoE(nn.Module):
+    def __init__(
+        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+    ):
         super().__init__()
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
@@ -396,7 +398,14 @@ class FlashLlamaLayer(nn.Module):
         )
 
         if config.model_type == "phimoe":
-            self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.dense = Phi3MoE(
+                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+            )
         # with moe the layernorms are are not rmsnorms and they have bias
         self.input_layernorm = FastLayerNorm.load(
             prefix=f"{prefix}.input_layernorm",
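
A minimal usage sketch (not part of the patch above): after this change, the MoE
implementation is picked the way FlashLlamaLayer does it, preferring the sparse
kernels and falling back to the dense layer when they are unavailable. The helper
name select_moe_layer_cls below is hypothetical and only for illustration; the
imports and the SparseMoELayer.is_supported() check come from the diff itself.

    from text_generation_server.layers.moe import DenseMoELayer, SparseMoELayer

    def select_moe_layer_cls(weights):
        # Hypothetical helper mirroring the selection in FlashLlamaLayer.__init__:
        # use the sparse MoE kernels when the loaded weights support them,
        # otherwise fall back to the dense MoE implementation.
        return SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer

    # The selected class is then passed to Phi3MoE as moe_layer_cls, so the
    # module itself stays agnostic to which implementation backs it.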