Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00
Add support for dense MoE
parent c07c80fac9
commit 0cafbf3b54
@@ -19,7 +19,7 @@
 # limitations under the License.

 from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type

 import torch
 import torch.distributed
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN

 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -253,14 +253,16 @@ class FlashLlamaAttention(torch.nn.Module):
         )


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
+class Phi3MoE(nn.Module):
+    def __init__(
+        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+    ):
         super().__init__()

         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
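The hunk above renames the gating block to Phi3MoE and parameterizes it over the MoE layer implementation: the gate projection is unchanged, and the expert layer (sparse or dense) is injected via moe_layer_cls. For readers unfamiliar with the distinction, the sketch below contrasts dense and sparse routing in plain PyTorch. It is illustrative only and does not reproduce DenseMoELayer/SparseMoELayer from text_generation_server; the names toy_dense_moe and toy_sparse_moe are made up for this example.

# Illustrative sketch only (plain PyTorch): dense routing runs every expert on
# every token, sparse routing runs only the top-k experts per token.
import torch


def toy_dense_moe(x, gate, experts):
    # x: [tokens, hidden], gate: [hidden, n_experts], experts: list of callables
    weights = torch.softmax(x @ gate, dim=-1)  # [tokens, n_experts]
    out = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        # every expert processes every token, scaled by its gate weight
        out += weights[:, i : i + 1] * expert(x)
    return out


def toy_sparse_moe(x, gate, experts, top_k=2):
    probs = torch.softmax(x @ gate, dim=-1)
    top_w, top_idx = torch.topk(probs, k=top_k, dim=-1)
    top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # renormalize over selected experts
    out = torch.zeros_like(x)
    for slot in range(top_k):
        for i, expert in enumerate(experts):
            mask = top_idx[:, slot] == i
            if mask.any():
                # only the selected experts run, and only on their assigned tokens
                out[mask] += top_w[mask, slot : slot + 1] * expert(x[mask])
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden, n_experts = 16, 4
    experts = [torch.nn.Linear(hidden, hidden) for _ in range(n_experts)]
    x = torch.randn(8, hidden)
    gate_proj = torch.randn(hidden, n_experts)
    with torch.no_grad():
        print(toy_dense_moe(x, gate_proj, experts).shape)   # torch.Size([8, 16])
        print(toy_sparse_moe(x, gate_proj, experts).shape)  # torch.Size([8, 16])

Both variants produce the same output shape; they differ only in how many experts are evaluated per token, which is exactly the degree of freedom the moe_layer_cls parameter exposes.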
@@ -396,7 +398,14 @@ class FlashLlamaLayer(nn.Module):
         )

         if config.model_type == "phimoe":
-            self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.dense = Phi3MoE(
+                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+            )
             # with moe the layernorms are not rmsnorms and they have bias
             self.input_layernorm = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
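The selection above prefers SparseMoELayer whenever SparseMoELayer.is_supported(weights) reports that the loaded weights can use it, and otherwise falls back to the new DenseMoELayer, so PhiMoE checkpoints still load when the sparse path is unavailable (presumably for weight formats the fused sparse kernels do not cover; the diff itself does not state the reason). Below is a minimal standalone sketch of the same select-by-capability pattern, using hypothetical stand-in classes rather than the real TGI layers.

# Minimal sketch of the capability-check fallback used in the hunk above.
# FastSparseImpl and PortableDenseImpl are hypothetical stand-ins, not the real
# SparseMoELayer/DenseMoELayer; the real is_supported() receives the loaded
# Weights object rather than a format string.


class FastSparseImpl:
    """Stand-in for an optimized implementation with extra requirements."""

    @staticmethod
    def is_supported(weight_format: str) -> bool:
        # assume the fast path only exists for a few weight formats
        return weight_format in {"fp16", "bf16"}

    def __call__(self, x):
        return [2.0 * v for v in x]


class PortableDenseImpl:
    """Stand-in for a slower fallback that works for any weight format."""

    def __call__(self, x):
        return [2.0 * v for v in x]


def build_moe(weight_format: str):
    # Same shape as the diff: prefer the fast path when supported, else fall back.
    impl_cls = (
        FastSparseImpl
        if FastSparseImpl.is_supported(weight_format)
        else PortableDenseImpl
    )
    return impl_cls()


print(type(build_moe("bf16")).__name__)  # FastSparseImpl
print(type(build_moe("gptq")).__name__)  # PortableDenseImpl

Passing the selected class into Phi3MoE, rather than branching inside it, keeps the gate and expert wiring identical for both code paths.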