Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Use SparseMoELayer
parent f4cadd7527
commit c07c80fac9
@@ -28,6 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
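Beyond the new import, the rest of this diff relies only on a constructor-plus-call contract for SparseMoELayer. The stub below is a hypothetical stand-in inferred from the call sites later in this commit; it is not the actual class in text_generation_server/layers/moe, just a sketch of the interface the modeling code assumes.

import torch
from torch import nn


class SparseMoELayerStub(nn.Module):
    """Hypothetical stand-in mirroring how this file uses SparseMoELayer."""

    def __init__(
        self,
        *,
        prefix,          # e.g. f"{prefix}.experts"; per-expert weights live under it
        n_experts,       # config.num_local_experts
        n_expert_group,  # None for Mixtral-style models
        renormalize,     # rescale the top-k routing weights to sum to 1
        topk,            # config.num_experts_per_tok
        topk_group,      # None for Mixtral-style models
        weights,         # TGI weight loader (sharded across ranks)
        gate_proj_name,  # "w1" in Mixtral checkpoints
        up_proj_name,    # "w3"
        down_proj_name,  # "w2"
    ):
        super().__init__()
        # The real layer loads and shards the per-expert weights here.

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        # The real layer mixes each token's top-k experts; output shape == x.shape.
        raise NotImplementedError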
@@ -252,79 +253,24 @@ class FlashLlamaAttention(torch.nn.Module):
         )
 
 
-def _load_experts(config, prefix: str, mat, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Mixtral does not support weight quantization yet.")
-
-    assert mat in ["w1", "w2", "w3"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.intermediate_size % world_size == 0
-    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.num_local_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.num_local_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "w2":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
 class BlockSparseMoE(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
+        self.moe = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.num_local_experts,
+            n_expert_group=None,
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
+        )
 
         self.process_group = weights.process_group
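The deleted _load_experts helper is where tensor parallelism used to be handled for the experts: each rank kept a block_size = intermediate_size // world_size slice, taken from the rows of w1/w3 and from the columns of w2 (then transposed), stacked per expert into one (num_local_experts * block_size, hidden_size) buffer. With SparseMoELayer the model only declares the prefix and the projection names (gate_proj_name="w1", up_proj_name="w3", down_proj_name="w2") and the layer does its own loading. The snippet below restates the removed sharding arithmetic on plain tensors as a readability aid; expert_weights is a stand-in list of per-expert dicts rather than TGI's weights object.

import torch


def shard_experts(expert_weights, mat, world_size, rank, intermediate_size, hidden_size):
    # expert_weights[i][mat] is expert i's full matrix:
    #   w1, w3: (intermediate_size, hidden_size)  -> take a slice of rows
    #   w2:     (hidden_size, intermediate_size)  -> take a slice of columns, then transpose
    assert mat in ["w1", "w2", "w3"]
    assert intermediate_size % world_size == 0
    block_size = intermediate_size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size

    out = torch.empty((len(expert_weights) * block_size, hidden_size))
    for i, expert in enumerate(expert_weights):
        if mat == "w2":
            shard = expert[mat][:, start:stop].t().contiguous()
        else:
            shard = expert[mat][start:stop]
        out[i * block_size : (i + 1) * block_size] = shard
    return out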
@@ -332,15 +278,7 @@ class BlockSparseMoE(nn.Module):
     def forward(self, x, adapter_data) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)
 
         # Reduce sum
         if self.process_group.size() > 1:
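The old forward handed the merged w13/w2 buffers, the router logits, and the top_k/renormalize options to fused_moe; the new one passes only the hidden states and the router logits to self.moe. As a reference for what that call computes, here is a deliberately slow, dense sketch of top-k, renormalized expert mixing with SiLU-gated (Mixtral-style) expert MLPs. It illustrates the math only, not the fused kernel, and it uses separate per-expert w1/w3/w2 tensors instead of the merged layout the old code built.

import torch


def moe_reference(x, router_logits, w1, w3, w2, top_k):
    # x:             (num_tokens, hidden_size)
    # router_logits: (num_tokens, n_experts)
    # w1, w3:        (n_experts, ffn_dim, hidden_size)  gate / up projections
    # w2:            (n_experts, hidden_size, ffn_dim)  down projection
    probs = torch.softmax(router_logits, dim=-1)
    topk_probs, topk_ids = probs.topk(top_k, dim=-1)
    # renormalize=True in the old fused_moe call: the chosen experts' weights
    # are rescaled to sum to 1 for every token.
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(router_logits.shape[-1]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        xe = x[token_idx]
        # SwiGLU expert MLP: down(silu(gate(x)) * up(x)), i.e. w2(silu(w1 x) * (w3 x))
        h = torch.nn.functional.silu(xe @ w1[e].t()) * (xe @ w3[e].t())
        out[token_idx] += topk_probs[token_idx, slot].unsqueeze(-1) * (h @ w2[e].t())
    return out


# Example: 4 tokens, 8 experts, top-2 routing.
x = torch.randn(4, 16)
logits = torch.randn(4, 8)
w1, w3, w2 = torch.randn(8, 32, 16), torch.randn(8, 32, 16), torch.randn(8, 16, 32)
print(moe_reference(x, logits, w1, w3, w2, top_k=2).shape)  # torch.Size([4, 16])

In the model this is followed by the all-reduce guarded by self.process_group.size() > 1, which sums the partial expert outputs held by each tensor-parallel rank.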