Add support for dense MoE

Daniël de Kok 2024-09-25 08:47:52 +00:00
parent c07c80fac9
commit 0cafbf3b54


@@ -19,7 +19,7 @@
 # limitations under the License.

 from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type

 import torch
 import torch.distributed
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN

 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -253,14 +253,16 @@ class FlashLlamaAttention(torch.nn.Module):
         )


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
+class Phi3MoE(nn.Module):
+    def __init__(
+        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+    ):
         super().__init__()

         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
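
The constructor change above makes the MoE implementation pluggable: Phi3MoE now receives moe_layer_cls instead of hard-coding SparseMoELayer. As intuition for what a dense fallback computes, here is a self-contained PyTorch sketch; it is not TGI's actual DenseMoELayer, and the class name, the SiLU expert MLP, and the renormalization step are illustrative assumptions. Every expert runs on every token, and the router's top-k weights mix the outputs, so no fused sparse kernels are needed.

import torch
import torch.nn.functional as F
from torch import nn

class DenseMoESketch(nn.Module):
    """Illustrative dense MoE: all experts run on all tokens; the router's
    top-k weights (zeroed elsewhere) mix the expert outputs."""

    def __init__(self, hidden: int, ffn: int, n_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden, n_experts, bias=False)  # router
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden, ffn, bias=False),
                nn.SiLU(),
                nn.Linear(ffn, hidden, bias=False),
            )
            for _ in range(n_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [tokens, hidden]
        gates = F.softmax(self.gate(x), dim=-1)               # [tokens, n_experts]
        topk_w, topk_i = gates.topk(self.top_k, dim=-1)
        mask = torch.zeros_like(gates).scatter(-1, topk_i, topk_w)
        mask = mask / mask.sum(dim=-1, keepdim=True)          # renormalize top-k (assumed)
        # Dense evaluation: no token dispatch, just run every expert.
        out = torch.stack([e(x) for e in self.experts], dim=1)  # [tokens, E, hidden]
        return (mask.unsqueeze(-1) * out).sum(dim=1)          # [tokens, hidden]

This trades extra compute (every expert processes every token) for portability; a sparse implementation would instead dispatch each token only to its top-k experts via fused kernels.
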
@@ -396,7 +398,14 @@ class FlashLlamaLayer(nn.Module):
         )

         if config.model_type == "phimoe":
-            self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.dense = Phi3MoE(
+                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+            )
             # with moe the layernorms are not rmsnorms and they have bias
             self.input_layernorm = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",