Add support for dense MoE

Daniël de Kok 2024-09-25 08:47:52 +00:00
parent c07c80fac9
commit 0cafbf3b54

@@ -19,7 +19,7 @@
 # limitations under the License.

 from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type

 import torch
 import torch.distributed
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -253,14 +253,16 @@ class FlashLlamaAttention(torch.nn.Module):
         )


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
+class Phi3MoE(nn.Module):
+    def __init__(
+        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+    ):
         super().__init__()
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
@@ -396,7 +398,14 @@ class FlashLlamaLayer(nn.Module):
         )

         if config.model_type == "phimoe":
-            self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.dense = Phi3MoE(
+                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+            )
             # with moe the layernorms are are not rmsnorms and they have bias
             self.input_layernorm = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",