Mirror of https://github.com/huggingface/text-generation-inference.git
Add support for dense MoE
commit 0cafbf3b54
parent c07c80fac9
@@ -19,7 +19,7 @@
 # limitations under the License.

 from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type

 import torch
 import torch.distributed
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN

 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -253,14 +253,16 @@ class FlashLlamaAttention(torch.nn.Module):
         )


-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
+class Phi3MoE(nn.Module):
+    def __init__(
+        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+    ):
         super().__init__()

         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
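Note: the hunk above changes the wrapper so that the expert implementation is passed in as a class (moe_layer_cls) instead of hard-coding SparseMoELayer. Below is a minimal, self-contained sketch of that pattern, with a toy dense layer standing in for DenseMoELayer; every Toy* name is illustrative only and not part of the TGI codebase.

from typing import Type

import torch
from torch import nn


class ToyDenseMoE(nn.Module):
    """Toy dense MoE: every expert runs on every token, no top-k dispatch."""

    def __init__(self, n_experts: int, dim: int):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_experts)])

    def forward(self, x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor:
        # x: [tokens, dim]; gates: [tokens, n_experts] router probabilities.
        expert_out = torch.stack([expert(x) for expert in self.experts], dim=1)
        return (gates.unsqueeze(-1) * expert_out).sum(dim=1)


class ToyMoEBlock(nn.Module):
    """Router plus experts; the expert implementation is injected as a class."""

    def __init__(self, dim: int, n_experts: int, moe_layer_cls: Type[nn.Module]):
        super().__init__()
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.moe = moe_layer_cls(n_experts=n_experts, dim=dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gates = torch.softmax(self.gate(x), dim=-1)
        return self.moe(x, gates)


block = ToyMoEBlock(dim=16, n_experts=4, moe_layer_cls=ToyDenseMoE)
print(block(torch.randn(3, 16)).shape)  # torch.Size([3, 16])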
@@ -396,7 +398,14 @@ class FlashLlamaLayer(nn.Module):
         )

         if config.model_type == "phimoe":
-            self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.dense = Phi3MoE(
+                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+            )
             # with moe the layernorms are are not rmsnorms and they have bias
             self.input_layernorm = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
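Note: the last hunk picks between the sparse and dense implementations once, at model-construction time, depending on whether the sparse kernels can handle the loaded weights. A self-contained sketch of that fallback pattern follows; the is_supported check here is a hypothetical stand-in for SparseMoELayer.is_supported(weights), and the Toy* classes are illustrative only.

from typing import Type

import torch
from torch import nn


class ToySparseMoE(nn.Module):
    """Stand-in for a fused, top-k sparse MoE layer that needs special kernels."""

    @staticmethod
    def is_supported() -> bool:
        # Hypothetical capability check, analogous in spirit to
        # SparseMoELayer.is_supported(weights) in the diff above.
        return torch.cuda.is_available()

    def __init__(self, n_experts: int, dim: int):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_experts)])


class ToyDenseMoE(nn.Module):
    """Fallback that works everywhere: every expert runs on every token."""

    def __init__(self, n_experts: int, dim: int):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_experts)])


# Pick the implementation once, then hand the class to the MoE wrapper,
# mirroring how the diff passes moe_layer_cls into Phi3MoE.
moe_layer_cls: Type[nn.Module] = (
    ToySparseMoE if ToySparseMoE.is_supported() else ToyDenseMoE
)
layer = moe_layer_cls(n_experts=4, dim=16)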